Patch description (free text ahead of the first file header).

Purpose
  Make the "case_when" compute function dispatch decimal value arguments whose
  decimal types differ by going through implicit casts instead of matching a
  decimal kernel directly.

Changes, per file
  * cpp/src/arrow/compute/kernel.{h,cc}, kernel_test.cc
      The free function MakeConstraint() becomes the static member
      MatchConstraint::Make(); the test is updated to the new spelling.
  * cpp/src/arrow/compute/kernels/scalar_if_else.cc
      - New CaseWhenFunction::DecimalMatchConstraint(): the predicate returns
        true only when every type from index 2 onward compares equal to
        types[1]. It DCHECKs that at least two types are present and that all
        types from index 1 onward are decimals.
      - AddCaseWhenKernel() gains an optional trailing `constraint` parameter
        (default nullptr) that is forwarded to KernelSignature::Make().
      - The DECIMAL128 and DECIMAL256 case_when kernels are registered with
        that constraint.
  * cpp/src/arrow/compute/kernels/scalar_if_else_test.cc, expression_test.cc
      New tests: exact dispatch succeeds for identical decimal types and fails
      for differing ones; best dispatch and expression binding are expected to
      promote/cast the value arguments to one common decimal type; scalar
      results are checked for every decimal128/decimal256 combination.

NOTE(review): this copy was restored from a flattened rendering of the patch in
which line breaks, indentation and template argument lists had been lost.
Indentation and the restored template arguments (in particular
std::shared_ptr<CaseWhenFunction>, std::pair<int, int> and the trailing
`template <typename Type>` context line) should be confirmed against the target
tree before applying.

diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc
index 02b5e2a1d920..541067aedd0a 100644
--- a/cpp/src/arrow/compute/expression_test.cc
+++ b/cpp/src/arrow/compute/expression_test.cc
@@ -717,6 +717,51 @@ TEST(Expression, BindWithImplicitCasts) {
                 call("is_in", {cast(field_ref("dict_str"), utf8())}, in_a));
 }
 
+TEST(Expression, BindWithImplicitCastsForCaseWhenOnDecimal) {
+  auto exciting_schema = schema(
+      {field("a", struct_({field("", boolean())})),
+       field("dec128_20_3", decimal128(20, 3)), field("dec128_21_3", decimal128(21, 3)),
+       field("dec128_20_1", decimal128(20, 1)), field("dec128_21_1", decimal128(21, 1)),
+       field("dec256_20_3", decimal256(20, 3)), field("dec256_21_3", decimal256(21, 3)),
+       field("dec256_20_1", decimal256(20, 1)), field("dec256_21_1", decimal256(21, 1))});
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
+                                   field_ref("dec128_21_3")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec128_20_3"), decimal128(21, 3)),
+                      field_ref("dec128_21_3")}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_1"),
+                                   field_ref("dec128_21_3")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec128_20_1"), decimal128(22, 3)),
+                      cast(field_ref("dec128_21_3"), decimal128(22, 3))}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
+                                   field_ref("dec128_21_1")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec128_20_3"), decimal128(23, 3)),
+                      cast(field_ref("dec128_21_1"), decimal128(23, 3))}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
+                                   field_ref("dec256_21_3")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec128_20_3"), decimal256(21, 3)),
+                      field_ref("dec256_21_3")}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_1"),
+                                   field_ref("dec128_21_3")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec256_20_1"), decimal256(22, 3)),
+                      cast(field_ref("dec128_21_3"), decimal256(22, 3))}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_3"),
+                                   field_ref("dec256_21_1")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec256_20_3"), decimal256(23, 3)),
+                      cast(field_ref("dec256_21_1"), decimal256(23, 3))}),
+                /*bound_out=*/nullptr, *exciting_schema);
+}
+
 TEST(Expression, BindNestedCall) {
   auto expr = add(field_ref("a"),
                   call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}),
diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc
index f7fecc9247b9..90e4aacb37a7 100644
--- a/cpp/src/arrow/compute/kernel.cc
+++ b/cpp/src/arrow/compute/kernel.cc
@@ -478,7 +478,7 @@ std::string OutputType::ToString() const {
 // ----------------------------------------------------------------------
 // MatchConstraint
 
-std::shared_ptr<MatchConstraint> MakeConstraint(
+std::shared_ptr<MatchConstraint> MatchConstraint::Make(
     std::function<bool(const std::vector<TypeHolder>&)> matches) {
   class FunctionMatchConstraint : public MatchConstraint {
    public:
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index fdcdb134de86..06b83c379f6c 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -356,11 +356,11 @@ class ARROW_EXPORT MatchConstraint {
 
   /// \brief Return true if the input types satisfy the constraint.
   virtual bool Matches(const std::vector<TypeHolder>& types) const = 0;
-};
 
-/// \brief Convenience function to create a MatchConstraint from a match function.
-ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
-    std::function<bool(const std::vector<TypeHolder>&)> matches);
+  /// \brief Convenience function to create a MatchConstraint from a match function.
+  static std::shared_ptr<MatchConstraint> Make(
+      std::function<bool(const std::vector<TypeHolder>&)> matches);
+};
 
 /// \brief Constraint that all input types are decimal types and have the same scale.
 ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();
diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc
index deaddaddc63f..3855695d417d 100644
--- a/cpp/src/arrow/compute/kernel_test.cc
+++ b/cpp/src/arrow/compute/kernel_test.cc
@@ -313,7 +313,7 @@ TEST(OutputType, Resolve) {
 TEST(MatchConstraint, ConvenienceMaker) {
   {
     auto always_match =
-        MakeConstraint([](const std::vector<TypeHolder>& types) { return true; });
+        MatchConstraint::Make([](const std::vector<TypeHolder>& types) { return true; });
 
     ASSERT_TRUE(always_match->Matches({}));
     ASSERT_TRUE(always_match->Matches({int8(), int16(), int32()}));
@@ -321,7 +321,7 @@ TEST(MatchConstraint, ConvenienceMaker) {
 
   {
     auto always_false =
-        MakeConstraint([](const std::vector<TypeHolder>& types) { return false; });
+        MatchConstraint::Make([](const std::vector<TypeHolder>& types) { return false; });
 
     ASSERT_FALSE(always_false->Matches({}));
     ASSERT_FALSE(always_false->Matches({int8(), int16(), int32()}));
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 753cc4de9f32..d885db4cd936 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1451,6 +1451,20 @@ struct CaseWhenFunction : ScalarFunction {
     if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
     return arrow::compute::detail::NoMatchingKernel(this, *types);
   }
+
+  static std::shared_ptr<MatchConstraint> DecimalMatchConstraint() {
+    static auto constraint =
+        MatchConstraint::Make([](const std::vector<TypeHolder>& types) -> bool {
+          DCHECK_GE(types.size(), 2);
+          DCHECK(std::all_of(types.begin() + 1, types.end(), [](const TypeHolder& type) {
+            return is_decimal(type.id());
+          }));
+          return std::all_of(
+              types.begin() + 2, types.end(),
+              [&types](const TypeHolder& type) { return type == types[1]; });
+        });
+    return constraint;
+  }
 };
 
 // Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions
@@ -2712,10 +2726,11 @@ struct ChooseFunction : ScalarFunction {
 };
 
 void AddCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function,
-                       detail::GetTypeId get_id, ArrayKernelExec exec) {
+                       detail::GetTypeId get_id, ArrayKernelExec exec,
+                       std::shared_ptr<MatchConstraint> constraint = nullptr) {
   ScalarKernel kernel(
       KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, LastType,
-                            /*is_varargs=*/true),
+                            /*is_varargs=*/true, std::move(constraint)),
       exec);
   if (is_fixed_width(get_id.id)) {
     kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
@@ -2890,8 +2905,10 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
     AddPrimitiveCaseWhenKernels(func, {boolean(), null(), float16()});
     AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
                       CaseWhenFunctor<FixedSizeBinaryType>::Exec);
-    AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<Decimal128Type>::Exec);
-    AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<Decimal256Type>::Exec);
+    AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<Decimal128Type>::Exec,
+                      CaseWhenFunction::DecimalMatchConstraint());
+    AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<Decimal256Type>::Exec,
+                      CaseWhenFunction::DecimalMatchConstraint());
     AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
     AddNestedCaseWhenKernels(func);
     DCHECK_OK(registry->AddFunction(std::move(func)));
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 2ff11dab430f..e007a16d13b8 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -1807,6 +1807,74 @@ TEST(TestCaseWhen, Decimal) {
   }
 }
 
+TEST(TestCaseWhen, DecimalPromotion) {
+  auto check_case_when_decimal_promotion =
+      [](std::shared_ptr<Scalar> body_true, std::shared_ptr<Scalar> body_false,
+         std::shared_ptr<Scalar> promoted_true, std::shared_ptr<Scalar> promoted_false) {
+        auto cond_true = ScalarFromJSON(boolean(), "true");
+        auto cond_false = ScalarFromJSON(boolean(), "false");
+        CheckScalar("case_when", {MakeStruct({cond_true}), body_true, body_false},
+                    promoted_true);
+        CheckScalar("case_when", {MakeStruct({cond_false}), body_true, body_false},
+                    promoted_false);
+      };
+
+  const std::vector<std::pair<int, int>> precisions = {{10, 20}, {15, 15}, {20, 10}};
+  const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
+  for (auto p : precisions) {
+    for (auto s : scales) {
+      auto p1 = p.first;
+      auto s1 = s.first;
+      auto p2 = p.second;
+      auto s2 = s.second;
+
+      auto max_scale = std::max({s1, s2});
+      auto scale_up_1 = max_scale - s1;
+      auto scale_up_2 = max_scale - s2;
+      auto max_precision = std::max({p1 + scale_up_1, p2 + scale_up_2});
+
+      // Operand string: 444.777...
+      std::string str_d1 =
+          R"(")" + std::string(p1 - s1, '4') + "." + std::string(s1, '7') + R"(")";
+      std::string str_d2 =
+          R"(")" + std::string(p2 - s2, '4') + "." + std::string(s2, '7') + R"(")";
+
+      // Promoted string: 444.777...000
+      std::string str_d1_promoted = R"(")" + std::string(p1 - s1, '4') + "." +
+                                    std::string(s1, '7') +
+                                    std::string(max_scale - s1, '0') + R"(")";
+      std::string str_d2_promoted = R"(")" + std::string(p2 - s2, '4') + "." +
+                                    std::string(s2, '7') +
+                                    std::string(max_scale - s2, '0') + R"(")";
+
+      auto d128_1 = decimal128(p1, s1);
+      auto d128_2 = decimal128(p2, s2);
+      auto d256_1 = decimal256(p1, s1);
+      auto d256_2 = decimal256(p2, s2);
+      auto d128_promoted = decimal128(max_precision, max_scale);
+      auto d256_promoted = decimal256(max_precision, max_scale);
+
+      auto scalar128_1 = ScalarFromJSON(d128_1, str_d1);
+      auto scalar128_2 = ScalarFromJSON(d128_2, str_d2);
+      auto scalar256_1 = ScalarFromJSON(d256_1, str_d1);
+      auto scalar256_2 = ScalarFromJSON(d256_2, str_d2);
+      auto scalar128_d1_promoted = ScalarFromJSON(d128_promoted, str_d1_promoted);
+      auto scalar128_d2_promoted = ScalarFromJSON(d128_promoted, str_d2_promoted);
+      auto scalar256_d1_promoted = ScalarFromJSON(d256_promoted, str_d1_promoted);
+      auto scalar256_d2_promoted = ScalarFromJSON(d256_promoted, str_d2_promoted);
+
+      check_case_when_decimal_promotion(scalar128_1, scalar128_2, scalar128_d1_promoted,
+                                        scalar128_d2_promoted);
+      check_case_when_decimal_promotion(scalar128_1, scalar256_2, scalar256_d1_promoted,
+                                        scalar256_d2_promoted);
+      check_case_when_decimal_promotion(scalar256_1, scalar128_2, scalar256_d1_promoted,
+                                        scalar256_d2_promoted);
+      check_case_when_decimal_promotion(scalar256_1, scalar256_2, scalar256_d1_promoted,
+                                        scalar256_d2_promoted);
+    }
+  }
+}
+
 TEST(TestCaseWhen, FixedSizeBinary) {
   auto type = fixed_size_binary(3);
   auto cond_true = ScalarFromJSON(boolean(), "true");
@@ -2509,6 +2577,28 @@ TEST(TestCaseWhen, UnionBoolStringRandom) {
   }
 }
 
+TEST(TestCaseWhen, DispatchExact) {
+  // Decimal types with same (p, s)
+  CheckDispatchExact("case_when", {struct_({field("", boolean())}), decimal128(20, 3),
+                                   decimal128(20, 3)});
+  CheckDispatchExact("case_when", {struct_({field("", boolean())}), decimal256(20, 3),
+                                   decimal256(20, 3)});
+
+  // Decimal types with different (p, s)
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal128(20, 3), decimal128(21, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal128(20, 1), decimal128(20, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal128(20, 3), decimal256(20, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal256(20, 3), decimal128(21, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal256(20, 3), decimal256(21, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal256(20, 1), decimal256(20, 3)});
+}
+
 TEST(TestCaseWhen, DispatchBest) {
   CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()},
                     {struct_({field("", boolean())}), int64(), int64()});
@@ -2559,6 +2649,32 @@ TEST(TestCaseWhen, DispatchBest) {
   CheckDispatchBest(
       "case_when", {struct_({field("", boolean())}), dictionary(int64(), utf8()), utf8()},
       {struct_({field("", boolean())}), utf8(), utf8()});
+
+  // Decimal promotion
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 3)},
+      {struct_({field("", boolean())}), decimal128(21, 3), decimal128(21, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal128(20, 1), decimal128(21, 3)},
+      {struct_({field("", boolean())}), decimal128(22, 3), decimal128(22, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 1)},
+      {struct_({field("", boolean())}), decimal128(23, 3), decimal128(23, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal128(20, 3), decimal256(21, 3)},
+      {struct_({field("", boolean())}), decimal256(21, 3), decimal256(21, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal256(20, 1), decimal128(21, 3)},
+      {struct_({field("", boolean())}), decimal256(22, 3), decimal256(22, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal256(20, 3), decimal256(21, 1)},
+      {struct_({field("", boolean())}), decimal256(23, 3), decimal256(23, 3)});
 }
 
 template <typename Type>