Skip to content

Commit 2e2aa0b

Browse files
authored
GH-40911: [C++][Compute] Fix the decimal division kernel dispatching (#47445)
### Rationale for this change The issues in #40911 and #39875 are the same: we have a fundamental defect when dispatching kernels for decimal division. The following statements assume both the dividend and the divisor are of the same decimal type (`Decimal32/64/128/256`), with possibly different `(p, s)`. * When doing `DispatchBest`, which is directly invoked through `CallFunction("divide", ...)`, w/o trying `DispatchExact` ahead, the dividend is ALWAYS promoted and the result will have the same `(p, s)` as the dividend, according to the rule listed in our documentation [1] (this is actually adopting the Redshift one [2]). * When doing `DispatchExact`, which is first tried by expression evaluation, there will be a match w/o any promotions so the subsequent try of `DispatchMatch` won't happen. The issue is obvious - `DispatchExact` and `DispatchBest` are conflicting - one saying "OK, for any `decimal128(p1, s1) / decimal128(p2, s2)`, it is a match" and the other saying "No, we must promote the dividend according to `(p1, s1)` and `(p2, s2)`". Then we actually have two choices to fix it: 1. Consider `DispatchBest` is doing the right thing (justified by [1]), and NEVER "exact match" any kernel for decimal division. This is what this PR does. The only problem is that we are basically ALWAYS rejecting a kernel to be "exactly matched" - weird, though functionally correct. 2. Consider `DispatchExact` is doing the right thing, and NOT promoting dividend in `DispatchBest`. The kernel is matched only based on their decimal type (not considering their `(p, s)`). And only the result is promoted (this also complies [1]). This is what the other attempting PR #40969 does. But that PR only claims a promoted result type w/o actually promoting the computation (i.e., the memory representation of a decimal needs to be promoted when doing the division) so the result is wrong. Though this is amendable by supporting basic decimal methods like `PromoteAndDivide` that does the promotion of the dividend and the division all together in one run, the modification can be cumbersome - the "scale up" needs to be propagated from the kernel definition all down to the basic decimal primitives. Besides, I assume this may not be as performant as doing batch promotion + batch division. [1] https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html#r_numeric_computations201-precision-and-scale-of-computed-decimal-results [2] https://arrow.apache.org/docs/cpp/compute.html#arithmetic-functions ### What changes are included in this PR? Suppress the `DispatchExact` for decimal division. Also, the match constraint `BinaryDecimalScale1GeScale2` introduced in #47297 becomes useless thus gets removed. ### Are these changes tested? Yes. ### Are there any user-facing changes? None. * GitHub Issue: #40911 Authored-by: Rossi Sun <[email protected]> Signed-off-by: Antoine Pitrou <[email protected]>
1 parent ed77d25 commit 2e2aa0b

File tree

8 files changed

+159
-104
lines changed

8 files changed

+159
-104
lines changed

cpp/src/arrow/compute/expression_test.cc

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -673,19 +673,72 @@ TEST(Expression, BindWithAliasCasts) {
673673
}
674674

675675
TEST(Expression, BindWithDecimalArithmeticOps) {
676-
for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) {
677-
auto expr = call(arith_op, {field_ref("d1"), field_ref("d2")});
678-
EXPECT_FALSE(expr.IsBound());
679-
680-
static const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
681-
for (auto s : scales) {
682-
auto schema = arrow::schema(
683-
{field("d1", decimal256(30, s.first)), field("d2", decimal256(20, s.second))});
684-
ExpectBindsTo(expr, no_change, &expr, *schema);
676+
static const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
677+
678+
for (std::string suffix : {"", "_checked"}) {
679+
for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) {
680+
std::string name = arith_op + suffix;
681+
SCOPED_TRACE(name);
682+
683+
for (auto s : scales) {
684+
auto schema = arrow::schema({field("d1", decimal256(30, s.first)),
685+
field("d2", decimal256(20, s.second))});
686+
auto expr = call(name, {field_ref("d1"), field_ref("d2")});
687+
EXPECT_FALSE(expr.IsBound());
688+
ExpectBindsTo(expr, no_change, &expr, *schema);
689+
}
685690
}
686691
}
687692
}
688693

694+
TEST(Expression, BindWithDecimalDivision) {
695+
auto expect_decimal_division_type = [](std::string name,
696+
std::shared_ptr<DataType> dividend,
697+
std::shared_ptr<DataType> divisor,
698+
std::shared_ptr<DataType> expected) {
699+
auto schema = arrow::schema({field("dividend", dividend), field("divisor", divisor)});
700+
auto expr = call(name, {field_ref("dividend"), field_ref("divisor")});
701+
ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*schema));
702+
EXPECT_TRUE(bound.IsBound());
703+
EXPECT_TRUE(bound.type()->Equals(expected));
704+
};
705+
706+
for (std::string name : {"divide", "divide_checked"}) {
707+
SCOPED_TRACE(name);
708+
709+
expect_decimal_division_type(name, int64(), arrow::decimal128(1, 0),
710+
decimal128(23, 4));
711+
expect_decimal_division_type(name, arrow::decimal128(1, 0), int64(),
712+
decimal128(21, 20));
713+
714+
expect_decimal_division_type(name, decimal128(2, 1), decimal128(2, 1),
715+
decimal128(6, 4));
716+
expect_decimal_division_type(name, decimal256(2, 1), decimal256(2, 1),
717+
decimal256(6, 4));
718+
expect_decimal_division_type(name, decimal128(2, 1), decimal256(2, 1),
719+
decimal256(6, 4));
720+
expect_decimal_division_type(name, decimal256(2, 1), decimal128(2, 1),
721+
decimal256(6, 4));
722+
723+
expect_decimal_division_type(name, decimal128(2, 0), decimal128(2, 1),
724+
decimal128(7, 4));
725+
expect_decimal_division_type(name, decimal128(2, 1), decimal128(2, 0),
726+
decimal128(5, 4));
727+
728+
// GH-39875: Expression call to decimal(3 ,2) / decimal(15, 2) wrong result type.
729+
// decimal128(3, 2) / decimal128(15, 2)
730+
// -> decimal128(19, 18) / decimal128(15, 2) = decimal128(19, 16)
731+
expect_decimal_division_type(name, decimal128(3, 2), decimal128(15, 2),
732+
decimal128(19, 16));
733+
734+
// GH-40911: Expression call to decimal(7 ,2) / decimal(6, 1) wrong result type.
735+
// decimal128(7, 2) / decimal128(6, 1)
736+
// -> decimal128(14, 9) / decimal128(6, 1) = decimal128(14, 8)
737+
expect_decimal_division_type(name, decimal128(7, 2), decimal128(6, 1),
738+
decimal128(14, 8));
739+
}
740+
}
741+
689742
TEST(Expression, BindWithImplicitCasts) {
690743
for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) {
691744
// cast arguments to common numeric type

cpp/src/arrow/compute/kernel.cc

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -519,30 +519,6 @@ std::shared_ptr<MatchConstraint> DecimalsHaveSameScale() {
519519
return instance;
520520
}
521521

522-
namespace {
523-
524-
template <typename Op>
525-
class BinaryDecimalScaleComparisonConstraint : public MatchConstraint {
526-
public:
527-
bool Matches(const std::vector<TypeHolder>& types) const override {
528-
DCHECK_EQ(types.size(), 2);
529-
DCHECK(is_decimal(types[0].id()));
530-
DCHECK(is_decimal(types[1].id()));
531-
const auto& ty0 = checked_cast<const DecimalType&>(*types[0].type);
532-
const auto& ty1 = checked_cast<const DecimalType&>(*types[1].type);
533-
return Op{}(ty0.scale(), ty1.scale());
534-
}
535-
};
536-
537-
} // namespace
538-
539-
std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2() {
540-
using BinaryDecimalScale1GeScale2Constraint =
541-
BinaryDecimalScaleComparisonConstraint<std::greater_equal<>>;
542-
static auto instance = std::make_shared<BinaryDecimalScale1GeScale2Constraint>();
543-
return instance;
544-
}
545-
546522
// ----------------------------------------------------------------------
547523
// KernelSignature
548524

cpp/src/arrow/compute/kernel.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,10 +365,6 @@ ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
365365
/// \brief Constraint that all input types are decimal types and have the same scale.
366366
ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();
367367

368-
/// \brief Constraint that all binary input types are decimal types and the first type's
369-
/// scale >= the second type's.
370-
ARROW_EXPORT std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2();
371-
372368
/// \brief Holds the input types, optional match constraint and output type of the kernel.
373369
///
374370
/// VarArgs functions with minimum N arguments should pass up to N input types to be

cpp/src/arrow/compute/kernel_test.cc

Lines changed: 24 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -341,23 +341,6 @@ TEST(MatchConstraint, DecimalsHaveSameScale) {
341341
decimal128(precision, scale + 1)}));
342342
}
343343

344-
TEST(MatchConstraint, BinaryDecimalScaleComparisonGE) {
345-
auto c = BinaryDecimalScale1GeScale2();
346-
constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
347-
ASSERT_TRUE(
348-
c->Matches({decimal128(precision, big_scale), decimal128(precision, small_scale)}));
349-
ASSERT_TRUE(
350-
c->Matches({decimal128(precision, big_scale), decimal256(precision, small_scale)}));
351-
ASSERT_TRUE(
352-
c->Matches({decimal256(precision, big_scale), decimal128(precision, small_scale)}));
353-
ASSERT_TRUE(
354-
c->Matches({decimal256(precision, big_scale), decimal256(precision, small_scale)}));
355-
ASSERT_TRUE(c->Matches(
356-
{decimal128(precision, small_scale), decimal128(precision, small_scale)}));
357-
ASSERT_FALSE(
358-
c->Matches({decimal128(precision, small_scale), decimal128(precision, big_scale)}));
359-
}
360-
361344
// ----------------------------------------------------------------------
362345
// KernelSignature
363346

@@ -471,31 +454,30 @@ TEST(KernelSignature, VarArgsMatchesInputs) {
471454
}
472455

473456
TEST(KernelSignature, MatchesInputsWithConstraint) {
474-
constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
475-
476-
auto small_scale_decimal = decimal128(precision, small_scale);
477-
auto big_scale_decimal = decimal128(precision, big_scale);
478-
479-
// No constraint.
480-
KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128}, boolean());
481-
ASSERT_TRUE(
482-
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
483-
ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));
484-
ASSERT_TRUE(
485-
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
486-
ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));
487-
488-
for (auto constraint : {DecimalsHaveSameScale(), BinaryDecimalScale1GeScale2()}) {
489-
KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(),
490-
/*is_varargs=*/false, constraint);
491-
ASSERT_EQ(constraint->Matches({small_scale_decimal, small_scale_decimal}),
492-
sig.MatchesInputs({small_scale_decimal, small_scale_decimal}));
493-
ASSERT_EQ(constraint->Matches({small_scale_decimal, big_scale_decimal}),
494-
sig.MatchesInputs({small_scale_decimal, big_scale_decimal}));
495-
ASSERT_EQ(constraint->Matches({big_scale_decimal, small_scale_decimal}),
496-
sig.MatchesInputs({big_scale_decimal, small_scale_decimal}));
497-
ASSERT_EQ(constraint->Matches({big_scale_decimal, big_scale_decimal}),
498-
sig.MatchesInputs({big_scale_decimal, big_scale_decimal}));
457+
auto precisions = {12, 22}, scales = {2, 3};
458+
for (auto p1 : precisions) {
459+
for (auto s1 : scales) {
460+
auto d1 = decimal128(p1, s1);
461+
for (auto p2 : precisions) {
462+
for (auto s2 : scales) {
463+
auto d2 = decimal128(p2, s2);
464+
465+
{
466+
// No constraint.
467+
KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128},
468+
boolean());
469+
ASSERT_TRUE(sig_no_constraint.MatchesInputs({d1, d2}));
470+
}
471+
472+
{
473+
// All decimal types must have the same scale.
474+
KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(),
475+
/*is_varargs=*/false, DecimalsHaveSameScale());
476+
ASSERT_EQ(sig.MatchesInputs({d1, d2}), s1 == s2);
477+
}
478+
}
479+
}
480+
}
499481
}
500482
}
501483

cpp/src/arrow/compute/kernels/scalar_arithmetic.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,6 @@ void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) {
670670
out_type = OutputType(ResolveDecimalMultiplicationOutput);
671671
} else if (op == "divide") {
672672
out_type = OutputType(ResolveDecimalDivisionOutput);
673-
constraint = BinaryDecimalScale1GeScale2();
674673
} else {
675674
DCHECK(false);
676675
}
@@ -727,6 +726,17 @@ ArrayKernelExec GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id)
727726
struct ArithmeticFunction : ScalarFunction {
728727
using ScalarFunction::ScalarFunction;
729728

729+
Result<const Kernel*> DispatchExact(
730+
const std::vector<TypeHolder>& types) const override {
731+
if ((name_ == "divide" || name_ == "divide_checked") && HasDecimal(types)) {
732+
// Decimal division ALWAYS scales up the dividend, so there will NEVER be an exact
733+
// match.
734+
return arrow::compute::detail::NoMatchingKernel(this, types);
735+
}
736+
737+
return ScalarFunction::DispatchExact(types);
738+
}
739+
730740
Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* types) const override {
731741
RETURN_NOT_OK(CheckArity(types->size()));
732742

cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1942,14 +1942,14 @@ TEST_F(TestBinaryArithmeticDecimal, DispatchExact) {
19421942
name += suffix;
19431943
ARROW_SCOPED_TRACE(name);
19441944

1945-
CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)});
1946-
CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)});
1947-
CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 0)});
1945+
CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 1)});
1946+
CheckDispatchExactFails(name, {decimal128(3, 1), decimal128(2, 1)});
1947+
CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 0)});
19481948
CheckDispatchExactFails(name, {decimal128(2, 0), decimal128(2, 1)});
19491949

1950-
CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)});
1951-
CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)});
1952-
CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 0)});
1950+
CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 1)});
1951+
CheckDispatchExactFails(name, {decimal256(3, 1), decimal256(2, 1)});
1952+
CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 0)});
19531953
CheckDispatchExactFails(name, {decimal256(2, 0), decimal256(2, 1)});
19541954
}
19551955
}
@@ -2025,24 +2025,36 @@ TEST_F(TestBinaryArithmeticDecimal, DispatchBest) {
20252025
name += suffix;
20262026
SCOPED_TRACE(name);
20272027

2028-
CheckDispatchBest(name, {int64(), decimal128(1, 0)},
2029-
{decimal128(23, 4), decimal128(1, 0)});
2030-
CheckDispatchBest(name, {decimal128(1, 0), int64()},
2031-
{decimal128(21, 20), decimal128(19, 0)});
2032-
2033-
CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)},
2034-
{decimal128(6, 5), decimal128(2, 1)});
2035-
CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)},
2036-
{decimal256(6, 5), decimal256(2, 1)});
2037-
CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)},
2038-
{decimal256(6, 5), decimal256(2, 1)});
2039-
CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)},
2040-
{decimal256(6, 5), decimal256(2, 1)});
2041-
2042-
CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)},
2043-
{decimal128(7, 5), decimal128(2, 1)});
2044-
CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)},
2045-
{decimal128(5, 4), decimal128(2, 0)});
2028+
CheckDispatchBestWithCastedTypes(name, {int64(), decimal128(1, 0)},
2029+
{decimal128(23, 4), decimal128(1, 0)});
2030+
CheckDispatchBestWithCastedTypes(name, {decimal128(1, 0), int64()},
2031+
{decimal128(21, 20), decimal128(19, 0)});
2032+
2033+
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal128(2, 1)},
2034+
{decimal128(6, 5), decimal128(2, 1)});
2035+
CheckDispatchBestWithCastedTypes(name, {decimal256(2, 1), decimal256(2, 1)},
2036+
{decimal256(6, 5), decimal256(2, 1)});
2037+
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal256(2, 1)},
2038+
{decimal256(6, 5), decimal256(2, 1)});
2039+
CheckDispatchBestWithCastedTypes(name, {decimal256(2, 1), decimal128(2, 1)},
2040+
{decimal256(6, 5), decimal256(2, 1)});
2041+
2042+
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 0), decimal128(2, 1)},
2043+
{decimal128(7, 5), decimal128(2, 1)});
2044+
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal128(2, 0)},
2045+
{decimal128(5, 4), decimal128(2, 0)});
2046+
2047+
// GH-39875: Expression call to decimal(3 ,2) / decimal(15, 2) wrong result type.
2048+
// decimal128(3, 2) / decimal128(15, 2)
2049+
// -> decimal128(19, 18) / decimal128(15, 2) = decimal128(19, 16)
2050+
CheckDispatchBestWithCastedTypes(name, {decimal128(3, 2), decimal128(15, 2)},
2051+
{decimal128(19, 18), decimal128(15, 2)});
2052+
2053+
// GH-40911: Expression call to decimal(7 ,2) / decimal(6, 1) wrong result type.
2054+
// decimal128(7, 2) / decimal128(6, 1)
2055+
// -> decimal128(14, 9) / decimal128(6, 1) = decimal128(14, 8)
2056+
CheckDispatchBestWithCastedTypes(name, {decimal128(7, 2), decimal128(6, 1)},
2057+
{decimal128(14, 9), decimal128(6, 1)});
20462058
}
20472059
}
20482060
for (std::string name : {"atan2", "logb", "logb_checked", "power", "power_checked"}) {
@@ -2332,6 +2344,14 @@ TEST_F(TestBinaryArithmeticDecimal, Divide) {
23322344
CheckScalarBinary("divide", left, right, expected);
23332345
}
23342346

2347+
// decimal(p1, s1) decimal(p2, s2) where s1 < s2
2348+
{
2349+
auto left = ScalarFromJSON(decimal128(6, 5), R"("2.71828")");
2350+
auto right = ScalarFromJSON(decimal128(7, 6), R"("3.141592")");
2351+
auto expected = ScalarFromJSON(decimal128(14, 7), R"("0.8652555")");
2352+
CheckScalarBinary("divide", left, right, expected);
2353+
}
2354+
23352355
// decimal128 decimal256
23362356
{
23372357
auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")");

cpp/src/arrow/compute/kernels/test_util_internal.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,18 @@ void CheckDispatchBest(std::string func_name, std::vector<TypeHolder> original_v
311311
}
312312
}
313313

314+
void CheckDispatchBestWithCastedTypes(std::string func_name,
315+
std::vector<TypeHolder> values,
316+
const std::vector<TypeHolder>& expected_values) {
317+
ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
318+
ASSERT_OK_AND_ASSIGN(auto kernel, function->DispatchBest(&values));
319+
ASSERT_NE(kernel, nullptr);
320+
EXPECT_EQ(values.size(), expected_values.size());
321+
for (size_t i = 0; i < values.size(); i++) {
322+
AssertTypeEqual(*values[i], *expected_values[i]);
323+
}
324+
}
325+
314326
void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder> types) {
315327
ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
316328
ASSERT_NOT_OK(function->DispatchExact(types));

cpp/src/arrow/compute/kernels/test_util_internal.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ void CheckDispatchExact(std::string func_name, std::vector<TypeHolder> types);
163163
void CheckDispatchBest(std::string func_name, std::vector<TypeHolder> types,
164164
std::vector<TypeHolder> exact_types);
165165

166+
// Check that DispatchBest on a given function yields a valid Kernel and casts the input
167+
// types as expected
168+
void CheckDispatchBestWithCastedTypes(std::string func_name,
169+
std::vector<TypeHolder> types,
170+
const std::vector<TypeHolder>& expected_types);
171+
166172
// Check that function fails to produce a Kernel via DispatchExact for the set of types
167173
void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder> types);
168174

0 commit comments

Comments
 (0)