diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index 02b5e2a1d920..17c31cfdf3ad 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -634,19 +634,72 @@ TEST(Expression, BindWithAliasCasts) { } TEST(Expression, BindWithDecimalArithmeticOps) { - for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) { - auto expr = call(arith_op, {field_ref("d1"), field_ref("d2")}); - EXPECT_FALSE(expr.IsBound()); - - static const std::vector> scales = {{3, 9}, {6, 6}, {9, 3}}; - for (auto s : scales) { - auto schema = arrow::schema( - {field("d1", decimal256(30, s.first)), field("d2", decimal256(20, s.second))}); - ExpectBindsTo(expr, no_change, &expr, *schema); + static const std::vector> scales = {{3, 9}, {6, 6}, {9, 3}}; + + for (std::string suffix : {"", "_checked"}) { + for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) { + std::string name = arith_op + suffix; + SCOPED_TRACE(name); + + for (auto s : scales) { + auto schema = arrow::schema({field("d1", decimal256(30, s.first)), + field("d2", decimal256(20, s.second))}); + auto expr = call(name, {field_ref("d1"), field_ref("d2")}); + EXPECT_FALSE(expr.IsBound()); + ExpectBindsTo(expr, no_change, &expr, *schema); + } } } } +TEST(Expression, BindWithDecimalDivision) { + auto expect_decimal_division_type = [](std::string name, + std::shared_ptr dividend, + std::shared_ptr divisor, + std::shared_ptr expected) { + auto schema = arrow::schema({field("dividend", dividend), field("divisor", divisor)}); + auto expr = call(name, {field_ref("dividend"), field_ref("divisor")}); + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*schema)); + EXPECT_TRUE(bound.IsBound()); + EXPECT_TRUE(bound.type()->Equals(expected)); + }; + + for (std::string name : {"divide", "divide_checked"}) { + SCOPED_TRACE(name); + + expect_decimal_division_type(name, int64(), arrow::decimal128(1, 0), + decimal128(23, 4)); + expect_decimal_division_type(name, arrow::decimal128(1, 0), int64(), + decimal128(21, 20)); + + expect_decimal_division_type(name, decimal128(2, 1), decimal128(2, 1), + decimal128(6, 4)); + expect_decimal_division_type(name, decimal256(2, 1), decimal256(2, 1), + decimal256(6, 4)); + expect_decimal_division_type(name, decimal128(2, 1), decimal256(2, 1), + decimal256(6, 4)); + expect_decimal_division_type(name, decimal256(2, 1), decimal128(2, 1), + decimal256(6, 4)); + + expect_decimal_division_type(name, decimal128(2, 0), decimal128(2, 1), + decimal128(7, 4)); + expect_decimal_division_type(name, decimal128(2, 1), decimal128(2, 0), + decimal128(5, 4)); + + // GH-39875: Expression call to decimal(3 ,2) / decimal(15, 2) wrong result type. + // decimal128(3, 2) / decimal128(15, 2) + // -> decimal128(19, 18) / decimal128(15, 2) = decimal128(19, 16) + expect_decimal_division_type(name, decimal128(3, 2), decimal128(15, 2), + decimal128(19, 16)); + + // GH-40911: Expression call to decimal(7 ,2) / decimal(6, 1) wrong result type. + // decimal128(7, 2) / decimal128(6, 1) + // -> decimal128(14, 9) / decimal128(6, 1) = decimal128(14, 8) + expect_decimal_division_type(name, decimal128(7, 2), decimal128(6, 1), + decimal128(14, 8)); + } +} + TEST(Expression, BindWithImplicitCasts) { for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { // cast arguments to common numeric type diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index f7fecc9247b9..17f583c75fb1 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -519,30 +519,6 @@ std::shared_ptr DecimalsHaveSameScale() { return instance; } -namespace { - -template -class BinaryDecimalScaleComparisonConstraint : public MatchConstraint { - public: - bool Matches(const std::vector& types) const override { - DCHECK_EQ(types.size(), 2); - DCHECK(is_decimal(types[0].id())); - DCHECK(is_decimal(types[1].id())); - const auto& ty0 = checked_cast(*types[0].type); - const auto& ty1 = checked_cast(*types[1].type); - return Op{}(ty0.scale(), ty1.scale()); - } -}; - -} // namespace - -std::shared_ptr BinaryDecimalScale1GeScale2() { - using BinaryDecimalScale1GeScale2Constraint = - BinaryDecimalScaleComparisonConstraint>; - static auto instance = std::make_shared(); - return instance; -} - // ---------------------------------------------------------------------- // KernelSignature diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index fdcdb134de86..fa2e98346957 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -365,10 +365,6 @@ ARROW_EXPORT std::shared_ptr MakeConstraint( /// \brief Constraint that all input types are decimal types and have the same scale. ARROW_EXPORT std::shared_ptr DecimalsHaveSameScale(); -/// \brief Constraint that all binary input types are decimal types and the first type's -/// scale >= the second type's. -ARROW_EXPORT std::shared_ptr BinaryDecimalScale1GeScale2(); - /// \brief Holds the input types, optional match constraint and output type of the kernel. /// /// VarArgs functions with minimum N arguments should pass up to N input types to be diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index deaddaddc63f..374d1458f425 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -341,23 +341,6 @@ TEST(MatchConstraint, DecimalsHaveSameScale) { decimal128(precision, scale + 1)})); } -TEST(MatchConstraint, BinaryDecimalScaleComparisonGE) { - auto c = BinaryDecimalScale1GeScale2(); - constexpr int32_t precision = 12, small_scale = 2, big_scale = 3; - ASSERT_TRUE( - c->Matches({decimal128(precision, big_scale), decimal128(precision, small_scale)})); - ASSERT_TRUE( - c->Matches({decimal128(precision, big_scale), decimal256(precision, small_scale)})); - ASSERT_TRUE( - c->Matches({decimal256(precision, big_scale), decimal128(precision, small_scale)})); - ASSERT_TRUE( - c->Matches({decimal256(precision, big_scale), decimal256(precision, small_scale)})); - ASSERT_TRUE(c->Matches( - {decimal128(precision, small_scale), decimal128(precision, small_scale)})); - ASSERT_FALSE( - c->Matches({decimal128(precision, small_scale), decimal128(precision, big_scale)})); -} - // ---------------------------------------------------------------------- // KernelSignature @@ -471,31 +454,30 @@ TEST(KernelSignature, VarArgsMatchesInputs) { } TEST(KernelSignature, MatchesInputsWithConstraint) { - constexpr int32_t precision = 12, small_scale = 2, big_scale = 3; - - auto small_scale_decimal = decimal128(precision, small_scale); - auto big_scale_decimal = decimal128(precision, big_scale); - - // No constraint. - KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128}, boolean()); - ASSERT_TRUE( - sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal})); - ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal})); - ASSERT_TRUE( - sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal})); - ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal})); - - for (auto constraint : {DecimalsHaveSameScale(), BinaryDecimalScale1GeScale2()}) { - KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(), - /*is_varargs=*/false, constraint); - ASSERT_EQ(constraint->Matches({small_scale_decimal, small_scale_decimal}), - sig.MatchesInputs({small_scale_decimal, small_scale_decimal})); - ASSERT_EQ(constraint->Matches({small_scale_decimal, big_scale_decimal}), - sig.MatchesInputs({small_scale_decimal, big_scale_decimal})); - ASSERT_EQ(constraint->Matches({big_scale_decimal, small_scale_decimal}), - sig.MatchesInputs({big_scale_decimal, small_scale_decimal})); - ASSERT_EQ(constraint->Matches({big_scale_decimal, big_scale_decimal}), - sig.MatchesInputs({big_scale_decimal, big_scale_decimal})); + auto precisions = {12, 22}, scales = {2, 3}; + for (auto p1 : precisions) { + for (auto s1 : scales) { + auto d1 = decimal128(p1, s1); + for (auto p2 : precisions) { + for (auto s2 : scales) { + auto d2 = decimal128(p2, s2); + + { + // No constraint. + KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128}, + boolean()); + ASSERT_TRUE(sig_no_constraint.MatchesInputs({d1, d2})); + } + + { + // All decimal types must have the same scale. + KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(), + /*is_varargs=*/false, DecimalsHaveSameScale()); + ASSERT_EQ(sig.MatchesInputs({d1, d2}), s1 == s2); + } + } + } + } } } diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index ccbd361362c4..03c9422809b8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -670,7 +670,6 @@ void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) { out_type = OutputType(ResolveDecimalMultiplicationOutput); } else if (op == "divide") { out_type = OutputType(ResolveDecimalDivisionOutput); - constraint = BinaryDecimalScale1GeScale2(); } else { DCHECK(false); } @@ -727,6 +726,17 @@ ArrayKernelExec GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id) struct ArithmeticFunction : ScalarFunction { using ScalarFunction::ScalarFunction; + Result DispatchExact( + const std::vector& types) const override { + if ((name_ == "divide" || name_ == "divide_checked") && HasDecimal(types)) { + // Decimal division ALWAYS scales up the dividend, so there will NEVER be an exact + // match. + return arrow::compute::detail::NoMatchingKernel(this, types); + } + + return ScalarFunction::DispatchExact(types); + } + Result DispatchBest(std::vector* types) const override { RETURN_NOT_OK(CheckArity(types->size())); diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 0956168fc348..6e9e0620f8f1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -1942,14 +1942,14 @@ TEST_F(TestBinaryArithmeticDecimal, DispatchExact) { name += suffix; ARROW_SCOPED_TRACE(name); - CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)}); - CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)}); - CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 0)}); + CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 1)}); + CheckDispatchExactFails(name, {decimal128(3, 1), decimal128(2, 1)}); + CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 0)}); CheckDispatchExactFails(name, {decimal128(2, 0), decimal128(2, 1)}); - CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)}); - CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)}); - CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 0)}); + CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 1)}); + CheckDispatchExactFails(name, {decimal256(3, 1), decimal256(2, 1)}); + CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 0)}); CheckDispatchExactFails(name, {decimal256(2, 0), decimal256(2, 1)}); } } @@ -2025,24 +2025,36 @@ TEST_F(TestBinaryArithmeticDecimal, DispatchBest) { name += suffix; SCOPED_TRACE(name); - CheckDispatchBest(name, {int64(), decimal128(1, 0)}, - {decimal128(23, 4), decimal128(1, 0)}); - CheckDispatchBest(name, {decimal128(1, 0), int64()}, - {decimal128(21, 20), decimal128(19, 0)}); - - CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, - {decimal128(6, 5), decimal128(2, 1)}); - CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, - {decimal256(6, 5), decimal256(2, 1)}); - CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)}, - {decimal256(6, 5), decimal256(2, 1)}); - CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, - {decimal256(6, 5), decimal256(2, 1)}); - - CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)}, - {decimal128(7, 5), decimal128(2, 1)}); - CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)}, - {decimal128(5, 4), decimal128(2, 0)}); + CheckDispatchBestWithCastedTypes(name, {int64(), decimal128(1, 0)}, + {decimal128(23, 4), decimal128(1, 0)}); + CheckDispatchBestWithCastedTypes(name, {decimal128(1, 0), int64()}, + {decimal128(21, 20), decimal128(19, 0)}); + + CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal128(2, 1)}, + {decimal128(6, 5), decimal128(2, 1)}); + CheckDispatchBestWithCastedTypes(name, {decimal256(2, 1), decimal256(2, 1)}, + {decimal256(6, 5), decimal256(2, 1)}); + CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal256(2, 1)}, + {decimal256(6, 5), decimal256(2, 1)}); + CheckDispatchBestWithCastedTypes(name, {decimal256(2, 1), decimal128(2, 1)}, + {decimal256(6, 5), decimal256(2, 1)}); + + CheckDispatchBestWithCastedTypes(name, {decimal128(2, 0), decimal128(2, 1)}, + {decimal128(7, 5), decimal128(2, 1)}); + CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal128(2, 0)}, + {decimal128(5, 4), decimal128(2, 0)}); + + // GH-39875: Expression call to decimal(3 ,2) / decimal(15, 2) wrong result type. + // decimal128(3, 2) / decimal128(15, 2) + // -> decimal128(19, 18) / decimal128(15, 2) = decimal128(19, 16) + CheckDispatchBestWithCastedTypes(name, {decimal128(3, 2), decimal128(15, 2)}, + {decimal128(19, 18), decimal128(15, 2)}); + + // GH-40911: Expression call to decimal(7 ,2) / decimal(6, 1) wrong result type. + // decimal128(7, 2) / decimal128(6, 1) + // -> decimal128(14, 9) / decimal128(6, 1) = decimal128(14, 8) + CheckDispatchBestWithCastedTypes(name, {decimal128(7, 2), decimal128(6, 1)}, + {decimal128(14, 9), decimal128(6, 1)}); } } for (std::string name : {"atan2", "logb", "logb_checked", "power", "power_checked"}) { @@ -2332,6 +2344,14 @@ TEST_F(TestBinaryArithmeticDecimal, Divide) { CheckScalarBinary("divide", left, right, expected); } + // decimal(p1, s1) decimal(p2, s2) where s1 < s2 + { + auto left = ScalarFromJSON(decimal128(6, 5), R"("2.71828")"); + auto right = ScalarFromJSON(decimal128(7, 6), R"("3.141592")"); + auto expected = ScalarFromJSON(decimal128(14, 7), R"("0.8652555")"); + CheckScalarBinary("divide", left, right, expected); + } + // decimal128 decimal256 { auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")"); diff --git a/cpp/src/arrow/compute/kernels/test_util_internal.cc b/cpp/src/arrow/compute/kernels/test_util_internal.cc index d1cad235ccbf..b184fe7d44c4 100644 --- a/cpp/src/arrow/compute/kernels/test_util_internal.cc +++ b/cpp/src/arrow/compute/kernels/test_util_internal.cc @@ -311,6 +311,18 @@ void CheckDispatchBest(std::string func_name, std::vector original_v } } +void CheckDispatchBestWithCastedTypes(std::string func_name, + std::vector values, + const std::vector& expected_values) { + ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name)); + ASSERT_OK_AND_ASSIGN(auto kernel, function->DispatchBest(&values)); + ASSERT_NE(kernel, nullptr); + EXPECT_EQ(values.size(), expected_values.size()); + for (size_t i = 0; i < values.size(); i++) { + AssertTypeEqual(*values[i], *expected_values[i]); + } +} + void CheckDispatchExactFails(std::string func_name, std::vector types) { ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name)); ASSERT_NOT_OK(function->DispatchExact(types)); diff --git a/cpp/src/arrow/compute/kernels/test_util_internal.h b/cpp/src/arrow/compute/kernels/test_util_internal.h index 1077101377c5..231e5762135a 100644 --- a/cpp/src/arrow/compute/kernels/test_util_internal.h +++ b/cpp/src/arrow/compute/kernels/test_util_internal.h @@ -163,6 +163,12 @@ void CheckDispatchExact(std::string func_name, std::vector types); void CheckDispatchBest(std::string func_name, std::vector types, std::vector exact_types); +// Check that DispatchBest on a given function yields a valid Kernel and casts the input +// types as expected +void CheckDispatchBestWithCastedTypes(std::string func_name, + std::vector types, + const std::vector& expected_types); + // Check that function fails to produce a Kernel via DispatchExact for the set of types void CheckDispatchExactFails(std::string func_name, std::vector types);