Skip to content

Commit 19672e1

Browse files
committed
Fix the decimal division kernel dispatching
1 parent dadc21f commit 19672e1

File tree

8 files changed

+154
-97
lines changed

8 files changed

+154
-97
lines changed

cpp/src/arrow/compute/expression_test.cc

Lines changed: 62 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -634,19 +634,72 @@ TEST(Expression, BindWithAliasCasts) {
634634
}
635635

636636
TEST(Expression, BindWithDecimalArithmeticOps) {
637-
for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) {
638-
auto expr = call(arith_op, {field_ref("d1"), field_ref("d2")});
639-
EXPECT_FALSE(expr.IsBound());
640-
641-
static const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
642-
for (auto s : scales) {
643-
auto schema = arrow::schema(
644-
{field("d1", decimal256(30, s.first)), field("d2", decimal256(20, s.second))});
645-
ExpectBindsTo(expr, no_change, &expr, *schema);
637+
static const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
638+
639+
for (std::string suffix : {"", "_checked"}) {
640+
for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) {
641+
std::string name = arith_op + suffix;
642+
SCOPED_TRACE(name);
643+
644+
for (auto s : scales) {
645+
auto schema = arrow::schema({field("d1", decimal256(30, s.first)),
646+
field("d2", decimal256(20, s.second))});
647+
auto expr = call(name, {field_ref("d1"), field_ref("d2")});
648+
EXPECT_FALSE(expr.IsBound());
649+
ExpectBindsTo(expr, no_change, &expr, *schema);
650+
}
646651
}
647652
}
648653
}
649654

655+
TEST(Expression, BindWithDecimalDivision) {
656+
auto expect_decimal_division_type = [](std::string name,
657+
std::shared_ptr<DataType> dividend,
658+
std::shared_ptr<DataType> divisor,
659+
std::shared_ptr<DataType> expected) {
660+
auto schema = arrow::schema({field("dividend", dividend), field("divisor", divisor)});
661+
auto expr = call(name, {field_ref("dividend"), field_ref("divisor")});
662+
ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*schema));
663+
EXPECT_TRUE(bound.IsBound());
664+
EXPECT_TRUE(bound.type()->Equals(expected));
665+
};
666+
667+
for (std::string name : {"divide", "divide_checked"}) {
668+
SCOPED_TRACE(name);
669+
670+
expect_decimal_division_type(name, int64(), arrow::decimal128(1, 0),
671+
decimal128(23, 4));
672+
expect_decimal_division_type(name, arrow::decimal128(1, 0), int64(),
673+
decimal128(21, 20));
674+
675+
expect_decimal_division_type(name, decimal128(2, 1), decimal128(2, 1),
676+
decimal128(6, 4));
677+
expect_decimal_division_type(name, decimal256(2, 1), decimal256(2, 1),
678+
decimal256(6, 4));
679+
expect_decimal_division_type(name, decimal128(2, 1), decimal256(2, 1),
680+
decimal256(6, 4));
681+
expect_decimal_division_type(name, decimal256(2, 1), decimal128(2, 1),
682+
decimal256(6, 4));
683+
684+
expect_decimal_division_type(name, decimal128(2, 0), decimal128(2, 1),
685+
decimal128(7, 4));
686+
expect_decimal_division_type(name, decimal128(2, 1), decimal128(2, 0),
687+
decimal128(5, 4));
688+
689+
// GH-39875: Expression call to decimal(3 ,2) / decimal(15, 2) wrong result type.
690+
// decimal128(3, 2) / decimal128(15, 2)
691+
// -> decimal128(19, 18) / decimal128(15, 2) = decimal128(19, 16)
692+
expect_decimal_division_type(name, decimal128(3, 2), decimal128(15, 2),
693+
decimal128(19, 16));
694+
695+
// GH-40911: Expression call to decimal(7 ,2) / decimal(6, 1) wrong result type.
696+
// decimal128(7, 2) / decimal128(6, 1)
697+
// -> decimal128(14, 9) / decimal128(6, 1) = decimal128(14, 8)
698+
expect_decimal_division_type(name, decimal128(7, 2), decimal128(6, 1),
699+
decimal128(14, 8));
700+
}
701+
}
702+
650703
TEST(Expression, BindWithImplicitCasts) {
651704
for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) {
652705
// cast arguments to common numeric type

cpp/src/arrow/compute/kernel.cc

Lines changed: 0 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -519,30 +519,6 @@ std::shared_ptr<MatchConstraint> DecimalsHaveSameScale() {
519519
return instance;
520520
}
521521

522-
namespace {
523-
524-
template <typename Op>
525-
class BinaryDecimalScaleComparisonConstraint : public MatchConstraint {
526-
public:
527-
bool Matches(const std::vector<TypeHolder>& types) const override {
528-
DCHECK_EQ(types.size(), 2);
529-
DCHECK(is_decimal(types[0].id()));
530-
DCHECK(is_decimal(types[1].id()));
531-
const auto& ty0 = checked_cast<const DecimalType&>(*types[0].type);
532-
const auto& ty1 = checked_cast<const DecimalType&>(*types[1].type);
533-
return Op{}(ty0.scale(), ty1.scale());
534-
}
535-
};
536-
537-
} // namespace
538-
539-
std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2() {
540-
using BinaryDecimalScale1GeScale2Constraint =
541-
BinaryDecimalScaleComparisonConstraint<std::greater_equal<>>;
542-
static auto instance = std::make_shared<BinaryDecimalScale1GeScale2Constraint>();
543-
return instance;
544-
}
545-
546522
// ----------------------------------------------------------------------
547523
// KernelSignature
548524

cpp/src/arrow/compute/kernel.h

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -365,10 +365,6 @@ ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
365365
/// \brief Constraint that all input types are decimal types and have the same scale.
366366
ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();
367367

368-
/// \brief Constraint that all binary input types are decimal types and the first type's
369-
/// scale >= the second type's.
370-
ARROW_EXPORT std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2();
371-
372368
/// \brief Holds the input types, optional match constraint and output type of the kernel.
373369
///
374370
/// VarArgs functions with minimum N arguments should pass up to N input types to be

cpp/src/arrow/compute/kernel_test.cc

Lines changed: 19 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -341,23 +341,6 @@ TEST(MatchConstraint, DecimalsHaveSameScale) {
341341
decimal128(precision, scale + 1)}));
342342
}
343343

344-
TEST(MatchConstraint, BinaryDecimalScaleComparisonGE) {
345-
auto c = BinaryDecimalScale1GeScale2();
346-
constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
347-
ASSERT_TRUE(
348-
c->Matches({decimal128(precision, big_scale), decimal128(precision, small_scale)}));
349-
ASSERT_TRUE(
350-
c->Matches({decimal128(precision, big_scale), decimal256(precision, small_scale)}));
351-
ASSERT_TRUE(
352-
c->Matches({decimal256(precision, big_scale), decimal128(precision, small_scale)}));
353-
ASSERT_TRUE(
354-
c->Matches({decimal256(precision, big_scale), decimal256(precision, small_scale)}));
355-
ASSERT_TRUE(c->Matches(
356-
{decimal128(precision, small_scale), decimal128(precision, small_scale)}));
357-
ASSERT_FALSE(
358-
c->Matches({decimal128(precision, small_scale), decimal128(precision, big_scale)}));
359-
}
360-
361344
// ----------------------------------------------------------------------
362345
// KernelSignature
363346

@@ -476,26 +459,27 @@ TEST(KernelSignature, MatchesInputsWithConstraint) {
476459
auto small_scale_decimal = decimal128(precision, small_scale);
477460
auto big_scale_decimal = decimal128(precision, big_scale);
478461

479-
// No constraint.
480-
KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128}, boolean());
481-
ASSERT_TRUE(
482-
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
483-
ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));
484-
ASSERT_TRUE(
485-
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
486-
ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));
462+
{
463+
// No constraint.
464+
KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128}, boolean());
465+
ASSERT_TRUE(
466+
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
467+
ASSERT_TRUE(
468+
sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));
469+
ASSERT_TRUE(
470+
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
471+
ASSERT_TRUE(
472+
sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));
473+
}
487474

488-
for (auto constraint : {DecimalsHaveSameScale(), BinaryDecimalScale1GeScale2()}) {
475+
{
476+
// All decimal types must have the same scale.
489477
KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(),
490-
/*is_varargs=*/false, constraint);
491-
ASSERT_EQ(constraint->Matches({small_scale_decimal, small_scale_decimal}),
492-
sig.MatchesInputs({small_scale_decimal, small_scale_decimal}));
493-
ASSERT_EQ(constraint->Matches({small_scale_decimal, big_scale_decimal}),
494-
sig.MatchesInputs({small_scale_decimal, big_scale_decimal}));
495-
ASSERT_EQ(constraint->Matches({big_scale_decimal, small_scale_decimal}),
496-
sig.MatchesInputs({big_scale_decimal, small_scale_decimal}));
497-
ASSERT_EQ(constraint->Matches({big_scale_decimal, big_scale_decimal}),
498-
sig.MatchesInputs({big_scale_decimal, big_scale_decimal}));
478+
/*is_varargs=*/false, DecimalsHaveSameScale());
479+
ASSERT_TRUE(sig.MatchesInputs({small_scale_decimal, small_scale_decimal}));
480+
ASSERT_FALSE(sig.MatchesInputs({small_scale_decimal, big_scale_decimal}));
481+
ASSERT_FALSE(sig.MatchesInputs({big_scale_decimal, small_scale_decimal}));
482+
ASSERT_TRUE(sig.MatchesInputs({big_scale_decimal, big_scale_decimal}));
499483
}
500484
}
501485

cpp/src/arrow/compute/kernels/scalar_arithmetic.cc

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -670,7 +670,6 @@ void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) {
670670
out_type = OutputType(ResolveDecimalMultiplicationOutput);
671671
} else if (op == "divide") {
672672
out_type = OutputType(ResolveDecimalDivisionOutput);
673-
constraint = BinaryDecimalScale1GeScale2();
674673
} else {
675674
DCHECK(false);
676675
}
@@ -727,6 +726,17 @@ ArrayKernelExec GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id)
727726
struct ArithmeticFunction : ScalarFunction {
728727
using ScalarFunction::ScalarFunction;
729728

729+
Result<const Kernel*> DispatchExact(
730+
const std::vector<TypeHolder>& types) const override {
731+
if ((name_ == "divide" || name_ == "divide_checked") && HasDecimal(types)) {
732+
// Decimal division ALWAYS scales up the dividend, so there will NEVER be an exact
733+
// match.
734+
return arrow::compute::detail::NoMatchingKernel(this, types);
735+
}
736+
737+
return ScalarFunction::DispatchExact(types);
738+
}
739+
730740
Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* types) const override {
731741
RETURN_NOT_OK(CheckArity(types->size()));
732742

cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc

Lines changed: 44 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1942,14 +1942,14 @@ TEST_F(TestBinaryArithmeticDecimal, DispatchExact) {
19421942
name += suffix;
19431943
ARROW_SCOPED_TRACE(name);
19441944

1945-
CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)});
1946-
CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)});
1947-
CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 0)});
1945+
CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 1)});
1946+
CheckDispatchExactFails(name, {decimal128(3, 1), decimal128(2, 1)});
1947+
CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 0)});
19481948
CheckDispatchExactFails(name, {decimal128(2, 0), decimal128(2, 1)});
19491949

1950-
CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)});
1951-
CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)});
1952-
CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 0)});
1950+
CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 1)});
1951+
CheckDispatchExactFails(name, {decimal256(3, 1), decimal256(2, 1)});
1952+
CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 0)});
19531953
CheckDispatchExactFails(name, {decimal256(2, 0), decimal256(2, 1)});
19541954
}
19551955
}
@@ -2025,24 +2025,36 @@ TEST_F(TestBinaryArithmeticDecimal, DispatchBest) {
20252025
name += suffix;
20262026
SCOPED_TRACE(name);
20272027

2028-
CheckDispatchBest(name, {int64(), decimal128(1, 0)},
2029-
{decimal128(23, 4), decimal128(1, 0)});
2030-
CheckDispatchBest(name, {decimal128(1, 0), int64()},
2031-
{decimal128(21, 20), decimal128(19, 0)});
2032-
2033-
CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)},
2034-
{decimal128(6, 5), decimal128(2, 1)});
2035-
CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)},
2036-
{decimal256(6, 5), decimal256(2, 1)});
2037-
CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)},
2038-
{decimal256(6, 5), decimal256(2, 1)});
2039-
CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)},
2040-
{decimal256(6, 5), decimal256(2, 1)});
2041-
2042-
CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)},
2043-
{decimal128(7, 5), decimal128(2, 1)});
2044-
CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)},
2045-
{decimal128(5, 4), decimal128(2, 0)});
2028+
CheckDispatchBestWithCastedTypes(name, {int64(), decimal128(1, 0)},
2029+
{decimal128(23, 4), decimal128(1, 0)});
2030+
CheckDispatchBestWithCastedTypes(name, {decimal128(1, 0), int64()},
2031+
{decimal128(21, 20), decimal128(19, 0)});
2032+
2033+
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal128(2, 1)},
2034+
{decimal128(6, 5), decimal128(2, 1)});
2035+
CheckDispatchBestWithCastedTypes(name, {decimal256(2, 1), decimal256(2, 1)},
2036+
{decimal256(6, 5), decimal256(2, 1)});
2037+
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal256(2, 1)},
2038+
{decimal256(6, 5), decimal256(2, 1)});
2039+
CheckDispatchBestWithCastedTypes(name, {decimal256(2, 1), decimal128(2, 1)},
2040+
{decimal256(6, 5), decimal256(2, 1)});
2041+
2042+
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 0), decimal128(2, 1)},
2043+
{decimal128(7, 5), decimal128(2, 1)});
2044+
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal128(2, 0)},
2045+
{decimal128(5, 4), decimal128(2, 0)});
2046+
2047+
// GH-39875: Expression call to decimal(3 ,2) / decimal(15, 2) wrong result type.
2048+
// decimal128(3, 2) / decimal128(15, 2)
2049+
// -> decimal128(19, 18) / decimal128(15, 2) = decimal128(19, 16)
2050+
CheckDispatchBestWithCastedTypes(name, {decimal128(3, 2), decimal128(15, 2)},
2051+
{decimal128(19, 18), decimal128(15, 2)});
2052+
2053+
// GH-40911: Expression call to decimal(7 ,2) / decimal(6, 1) wrong result type.
2054+
// decimal128(7, 2) / decimal128(6, 1)
2055+
// -> decimal128(14, 9) / decimal128(6, 1) = decimal128(14, 8)
2056+
CheckDispatchBestWithCastedTypes(name, {decimal128(7, 2), decimal128(6, 1)},
2057+
{decimal128(14, 9), decimal128(6, 1)});
20462058
}
20472059
}
20482060
for (std::string name : {"atan2", "logb", "logb_checked", "power", "power_checked"}) {
@@ -2332,6 +2344,14 @@ TEST_F(TestBinaryArithmeticDecimal, Divide) {
23322344
CheckScalarBinary("divide", left, right, expected);
23332345
}
23342346

2347+
// decimal(p1, s1) decimal(p2, s2) where s1 < s2
2348+
{
2349+
auto left = ScalarFromJSON(decimal128(6, 5), R"("2.71828")");
2350+
auto right = ScalarFromJSON(decimal128(7, 6), R"("3.141592")");
2351+
auto expected = ScalarFromJSON(decimal128(14, 7), R"("0.8652555")");
2352+
CheckScalarBinary("divide", left, right, expected);
2353+
}
2354+
23352355
// decimal128 decimal256
23362356
{
23372357
auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")");

cpp/src/arrow/compute/kernels/test_util_internal.cc

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -311,6 +311,18 @@ void CheckDispatchBest(std::string func_name, std::vector<TypeHolder> original_v
311311
}
312312
}
313313

314+
void CheckDispatchBestWithCastedTypes(std::string func_name,
315+
std::vector<TypeHolder> values,
316+
const std::vector<TypeHolder>& expected_values) {
317+
ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
318+
ASSERT_OK_AND_ASSIGN(auto kernel, function->DispatchBest(&values));
319+
ASSERT_NE(kernel, nullptr);
320+
EXPECT_EQ(values.size(), expected_values.size());
321+
for (size_t i = 0; i < values.size(); i++) {
322+
AssertTypeEqual(*values[i], *expected_values[i]);
323+
}
324+
}
325+
314326
void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder> types) {
315327
ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
316328
ASSERT_NOT_OK(function->DispatchExact(types));

cpp/src/arrow/compute/kernels/test_util_internal.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -163,6 +163,12 @@ void CheckDispatchExact(std::string func_name, std::vector<TypeHolder> types);
163163
void CheckDispatchBest(std::string func_name, std::vector<TypeHolder> types,
164164
std::vector<TypeHolder> exact_types);
165165

166+
// Check that DispatchBest on a given function yields a valid Kernel and casts the input
167+
// types as expected
168+
void CheckDispatchBestWithCastedTypes(std::string func_name,
169+
std::vector<TypeHolder> types,
170+
const std::vector<TypeHolder>& expected_types);
171+
166172
// Check that function fails to produce a Kernel via DispatchExact for the set of types
167173
void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder> types);
168174

0 commit comments

Comments
 (0)