Skip to content

Commit 8ccdbe7

Browse files
authored
GH-41336: [C++][Compute] Fix case_when kernel dispatch for decimals with different precisions and scales (#47479)
### Rationale for this change

This is another case of decimal kernels not being able to suppress exact matching when the precisions and scales of the arguments differ, which causes a wrong result type. After #47297, we have a systematic way to do that and to guide the matching toward the "best match" (applying implicit casts).

### What changes are included in this PR?

Simply added a match constraint that checks whether the precisions and scales of the decimal arguments are the same. Also added corresponding tests in the form of both expressions (exact match first, then best match) and function calls (best match only).

### Are these changes tested?

Yes.

### Are there any user-facing changes?

None.

* GitHub Issue: #41336

Authored-by: Rossi Sun <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
1 parent a444380 commit 8ccdbe7

File tree

6 files changed

+189
-11
lines changed

6 files changed

+189
-11
lines changed

cpp/src/arrow/compute/expression_test.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,51 @@ TEST(Expression, BindWithImplicitCasts) {
809809
call("is_in", {cast(field_ref("dict_str"), utf8())}, in_a));
810810
}
811811

812+
TEST(Expression, BindWithImplicitCastsForCaseWhenOnDecimal) {
813+
auto exciting_schema = schema(
814+
{field("a", struct_({field("", boolean())})),
815+
field("dec128_20_3", decimal128(20, 3)), field("dec128_21_3", decimal128(21, 3)),
816+
field("dec128_20_1", decimal128(20, 1)), field("dec128_21_1", decimal128(21, 1)),
817+
field("dec256_20_3", decimal256(20, 3)), field("dec256_21_3", decimal256(21, 3)),
818+
field("dec256_20_1", decimal256(20, 1)), field("dec256_21_1", decimal256(21, 1))});
819+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
820+
field_ref("dec128_21_3")}),
821+
call("case_when",
822+
{field_ref("a"), cast(field_ref("dec128_20_3"), decimal128(21, 3)),
823+
field_ref("dec128_21_3")}),
824+
/*bound_out=*/nullptr, *exciting_schema);
825+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_1"),
826+
field_ref("dec128_21_3")}),
827+
call("case_when",
828+
{field_ref("a"), cast(field_ref("dec128_20_1"), decimal128(22, 3)),
829+
cast(field_ref("dec128_21_3"), decimal128(22, 3))}),
830+
/*bound_out=*/nullptr, *exciting_schema);
831+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
832+
field_ref("dec128_21_1")}),
833+
call("case_when",
834+
{field_ref("a"), cast(field_ref("dec128_20_3"), decimal128(23, 3)),
835+
cast(field_ref("dec128_21_1"), decimal128(23, 3))}),
836+
/*bound_out=*/nullptr, *exciting_schema);
837+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
838+
field_ref("dec256_21_3")}),
839+
call("case_when",
840+
{field_ref("a"), cast(field_ref("dec128_20_3"), decimal256(21, 3)),
841+
field_ref("dec256_21_3")}),
842+
/*bound_out=*/nullptr, *exciting_schema);
843+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_1"),
844+
field_ref("dec128_21_3")}),
845+
call("case_when",
846+
{field_ref("a"), cast(field_ref("dec256_20_1"), decimal256(22, 3)),
847+
cast(field_ref("dec128_21_3"), decimal256(22, 3))}),
848+
/*bound_out=*/nullptr, *exciting_schema);
849+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_3"),
850+
field_ref("dec256_21_1")}),
851+
call("case_when",
852+
{field_ref("a"), cast(field_ref("dec256_20_3"), decimal256(23, 3)),
853+
cast(field_ref("dec256_21_1"), decimal256(23, 3))}),
854+
/*bound_out=*/nullptr, *exciting_schema);
855+
}
856+
812857
TEST(Expression, BindNestedCall) {
813858
auto expr = add(field_ref("a"),
814859
call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}),

cpp/src/arrow/compute/kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ std::string OutputType::ToString() const {
478478
// ----------------------------------------------------------------------
479479
// MatchConstraint
480480

481-
std::shared_ptr<MatchConstraint> MakeConstraint(
481+
std::shared_ptr<MatchConstraint> MatchConstraint::Make(
482482
std::function<bool(const std::vector<TypeHolder>&)> matches) {
483483
class FunctionMatchConstraint : public MatchConstraint {
484484
public:

cpp/src/arrow/compute/kernel.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,11 +356,11 @@ class ARROW_EXPORT MatchConstraint {
356356

357357
/// \brief Return true if the input types satisfy the constraint.
358358
virtual bool Matches(const std::vector<TypeHolder>& types) const = 0;
359-
};
360359

361-
/// \brief Convenience function to create a MatchConstraint from a match function.
362-
ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
363-
std::function<bool(const std::vector<TypeHolder>&)> matches);
360+
/// \brief Convenience function to create a MatchConstraint from a match function.
361+
static std::shared_ptr<MatchConstraint> Make(
362+
std::function<bool(const std::vector<TypeHolder>&)> matches);
363+
};
364364

365365
/// \brief Constraint that all input types are decimal types and have the same scale.
366366
ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();

cpp/src/arrow/compute/kernel_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,15 +313,15 @@ TEST(OutputType, Resolve) {
313313
TEST(MatchConstraint, ConvenienceMaker) {
314314
{
315315
auto always_match =
316-
MakeConstraint([](const std::vector<TypeHolder>& types) { return true; });
316+
MatchConstraint::Make([](const std::vector<TypeHolder>& types) { return true; });
317317

318318
ASSERT_TRUE(always_match->Matches({}));
319319
ASSERT_TRUE(always_match->Matches({int8(), int16(), int32()}));
320320
}
321321

322322
{
323323
auto always_false =
324-
MakeConstraint([](const std::vector<TypeHolder>& types) { return false; });
324+
MatchConstraint::Make([](const std::vector<TypeHolder>& types) { return false; });
325325

326326
ASSERT_FALSE(always_false->Matches({}));
327327
ASSERT_FALSE(always_false->Matches({int8(), int16(), int32()}));

cpp/src/arrow/compute/kernels/scalar_if_else.cc

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,6 +1451,20 @@ struct CaseWhenFunction : ScalarFunction {
14511451
if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
14521452
return arrow::compute::detail::NoMatchingKernel(this, *types);
14531453
}
1454+
1455+
static std::shared_ptr<MatchConstraint> DecimalMatchConstraint() {
1456+
static auto constraint =
1457+
MatchConstraint::Make([](const std::vector<TypeHolder>& types) -> bool {
1458+
DCHECK_GE(types.size(), 2);
1459+
DCHECK(std::all_of(types.begin() + 1, types.end(), [](const TypeHolder& type) {
1460+
return is_decimal(type.id());
1461+
}));
1462+
return std::all_of(
1463+
types.begin() + 2, types.end(),
1464+
[&types](const TypeHolder& type) { return type == types[1]; });
1465+
});
1466+
return constraint;
1467+
}
14541468
};
14551469

14561470
// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions
@@ -2712,10 +2726,11 @@ struct ChooseFunction : ScalarFunction {
27122726
};
27132727

27142728
void AddCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function,
2715-
detail::GetTypeId get_id, ArrayKernelExec exec) {
2729+
detail::GetTypeId get_id, ArrayKernelExec exec,
2730+
std::shared_ptr<MatchConstraint> constraint = nullptr) {
27162731
ScalarKernel kernel(
27172732
KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, LastType,
2718-
/*is_varargs=*/true),
2733+
/*is_varargs=*/true, std::move(constraint)),
27192734
exec);
27202735
if (is_fixed_width(get_id.id)) {
27212736
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
@@ -2890,8 +2905,10 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
28902905
AddPrimitiveCaseWhenKernels(func, {boolean(), null(), float16()});
28912906
AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
28922907
CaseWhenFunctor<FixedSizeBinaryType>::Exec);
2893-
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
2894-
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
2908+
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec,
2909+
CaseWhenFunction::DecimalMatchConstraint());
2910+
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec,
2911+
CaseWhenFunction::DecimalMatchConstraint());
28952912
AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
28962913
AddNestedCaseWhenKernels(func);
28972914
DCHECK_OK(registry->AddFunction(std::move(func)));

cpp/src/arrow/compute/kernels/scalar_if_else_test.cc

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,6 +1807,74 @@ TEST(TestCaseWhen, Decimal) {
18071807
}
18081808
}
18091809

1810+
TEST(TestCaseWhen, DecimalPromotion) {
1811+
auto check_case_when_decimal_promotion =
1812+
[](std::shared_ptr<Scalar> body_true, std::shared_ptr<Scalar> body_false,
1813+
std::shared_ptr<Scalar> promoted_true, std::shared_ptr<Scalar> promoted_false) {
1814+
auto cond_true = ScalarFromJSON(boolean(), "true");
1815+
auto cond_false = ScalarFromJSON(boolean(), "false");
1816+
CheckScalar("case_when", {MakeStruct({cond_true}), body_true, body_false},
1817+
promoted_true);
1818+
CheckScalar("case_when", {MakeStruct({cond_false}), body_true, body_false},
1819+
promoted_false);
1820+
};
1821+
1822+
const std::vector<std::pair<int, int>> precisions = {{10, 20}, {15, 15}, {20, 10}};
1823+
const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
1824+
for (auto p : precisions) {
1825+
for (auto s : scales) {
1826+
auto p1 = p.first;
1827+
auto s1 = s.first;
1828+
auto p2 = p.second;
1829+
auto s2 = s.second;
1830+
1831+
auto max_scale = std::max({s1, s2});
1832+
auto scale_up_1 = max_scale - s1;
1833+
auto scale_up_2 = max_scale - s2;
1834+
auto max_precision = std::max({p1 + scale_up_1, p2 + scale_up_2});
1835+
1836+
// Operand string: 444.777...
1837+
std::string str_d1 =
1838+
R"(")" + std::string(p1 - s1, '4') + "." + std::string(s1, '7') + R"(")";
1839+
std::string str_d2 =
1840+
R"(")" + std::string(p2 - s2, '4') + "." + std::string(s2, '7') + R"(")";
1841+
1842+
// Promoted string: 444.777...000
1843+
std::string str_d1_promoted = R"(")" + std::string(p1 - s1, '4') + "." +
1844+
std::string(s1, '7') +
1845+
std::string(max_scale - s1, '0') + R"(")";
1846+
std::string str_d2_promoted = R"(")" + std::string(p2 - s2, '4') + "." +
1847+
std::string(s2, '7') +
1848+
std::string(max_scale - s2, '0') + R"(")";
1849+
1850+
auto d128_1 = decimal128(p1, s1);
1851+
auto d128_2 = decimal128(p2, s2);
1852+
auto d256_1 = decimal256(p1, s1);
1853+
auto d256_2 = decimal256(p2, s2);
1854+
auto d128_promoted = decimal128(max_precision, max_scale);
1855+
auto d256_promoted = decimal256(max_precision, max_scale);
1856+
1857+
auto scalar128_1 = ScalarFromJSON(d128_1, str_d1);
1858+
auto scalar128_2 = ScalarFromJSON(d128_2, str_d2);
1859+
auto scalar256_1 = ScalarFromJSON(d256_1, str_d1);
1860+
auto scalar256_2 = ScalarFromJSON(d256_2, str_d2);
1861+
auto scalar128_d1_promoted = ScalarFromJSON(d128_promoted, str_d1_promoted);
1862+
auto scalar128_d2_promoted = ScalarFromJSON(d128_promoted, str_d2_promoted);
1863+
auto scalar256_d1_promoted = ScalarFromJSON(d256_promoted, str_d1_promoted);
1864+
auto scalar256_d2_promoted = ScalarFromJSON(d256_promoted, str_d2_promoted);
1865+
1866+
check_case_when_decimal_promotion(scalar128_1, scalar128_2, scalar128_d1_promoted,
1867+
scalar128_d2_promoted);
1868+
check_case_when_decimal_promotion(scalar128_1, scalar256_2, scalar256_d1_promoted,
1869+
scalar256_d2_promoted);
1870+
check_case_when_decimal_promotion(scalar256_1, scalar128_2, scalar256_d1_promoted,
1871+
scalar256_d2_promoted);
1872+
check_case_when_decimal_promotion(scalar256_1, scalar256_2, scalar256_d1_promoted,
1873+
scalar256_d2_promoted);
1874+
}
1875+
}
1876+
}
1877+
18101878
TEST(TestCaseWhen, FixedSizeBinary) {
18111879
auto type = fixed_size_binary(3);
18121880
auto cond_true = ScalarFromJSON(boolean(), "true");
@@ -2509,6 +2577,28 @@ TEST(TestCaseWhen, UnionBoolStringRandom) {
25092577
}
25102578
}
25112579

2580+
TEST(TestCaseWhen, DispatchExact) {
2581+
// Decimal types with same (p, s)
2582+
CheckDispatchExact("case_when", {struct_({field("", boolean())}), decimal128(20, 3),
2583+
decimal128(20, 3)});
2584+
CheckDispatchExact("case_when", {struct_({field("", boolean())}), decimal256(20, 3),
2585+
decimal256(20, 3)});
2586+
2587+
// Decimal types with different (p, s)
2588+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2589+
decimal128(20, 3), decimal128(21, 3)});
2590+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2591+
decimal128(20, 1), decimal128(20, 3)});
2592+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2593+
decimal128(20, 3), decimal256(20, 3)});
2594+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2595+
decimal256(20, 3), decimal128(21, 3)});
2596+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2597+
decimal256(20, 3), decimal256(21, 3)});
2598+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2599+
decimal256(20, 1), decimal256(20, 3)});
2600+
}
2601+
25122602
TEST(TestCaseWhen, DispatchBest) {
25132603
CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()},
25142604
{struct_({field("", boolean())}), int64(), int64()});
@@ -2559,6 +2649,32 @@ TEST(TestCaseWhen, DispatchBest) {
25592649
CheckDispatchBest(
25602650
"case_when", {struct_({field("", boolean())}), dictionary(int64(), utf8()), utf8()},
25612651
{struct_({field("", boolean())}), utf8(), utf8()});
2652+
2653+
// Decimal promotion
2654+
CheckDispatchBest(
2655+
"case_when",
2656+
{struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 3)},
2657+
{struct_({field("", boolean())}), decimal128(21, 3), decimal128(21, 3)});
2658+
CheckDispatchBest(
2659+
"case_when",
2660+
{struct_({field("", boolean())}), decimal128(20, 1), decimal128(21, 3)},
2661+
{struct_({field("", boolean())}), decimal128(22, 3), decimal128(22, 3)});
2662+
CheckDispatchBest(
2663+
"case_when",
2664+
{struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 1)},
2665+
{struct_({field("", boolean())}), decimal128(23, 3), decimal128(23, 3)});
2666+
CheckDispatchBest(
2667+
"case_when",
2668+
{struct_({field("", boolean())}), decimal128(20, 3), decimal256(21, 3)},
2669+
{struct_({field("", boolean())}), decimal256(21, 3), decimal256(21, 3)});
2670+
CheckDispatchBest(
2671+
"case_when",
2672+
{struct_({field("", boolean())}), decimal256(20, 1), decimal128(21, 3)},
2673+
{struct_({field("", boolean())}), decimal256(22, 3), decimal256(22, 3)});
2674+
CheckDispatchBest(
2675+
"case_when",
2676+
{struct_({field("", boolean())}), decimal256(20, 3), decimal256(21, 1)},
2677+
{struct_({field("", boolean())}), decimal256(23, 3), decimal256(23, 3)});
25622678
}
25632679

25642680
template <typename Type>

0 commit comments

Comments
 (0)