Skip to content

Commit 48e52c1

Browse files
committed
Fix case_when kernel dispatch for decimals with different (p, s)
1 parent 6f6138b commit 48e52c1

File tree

6 files changed

+192
-11
lines changed

6 files changed

+192
-11
lines changed

cpp/src/arrow/compute/expression_test.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,51 @@ TEST(Expression, BindWithImplicitCasts) {
717717
call("is_in", {cast(field_ref("dict_str"), utf8())}, in_a));
718718
}
719719

720+
TEST(Expression, BindWithImplicitCastsForCaseWhenOnDecimal) {
721+
auto exciting_schema = schema(
722+
{field("a", struct_({field("", boolean())})),
723+
field("dec128_20_3", decimal128(20, 3)), field("dec128_21_3", decimal128(21, 3)),
724+
field("dec128_20_1", decimal128(20, 1)), field("dec128_21_1", decimal128(21, 1)),
725+
field("dec256_20_3", decimal256(20, 3)), field("dec256_21_3", decimal256(21, 3)),
726+
field("dec256_20_1", decimal256(20, 1)), field("dec256_21_1", decimal256(21, 1))});
727+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
728+
field_ref("dec128_21_3")}),
729+
call("case_when",
730+
{field_ref("a"), cast(field_ref("dec128_20_3"), decimal128(21, 3)),
731+
field_ref("dec128_21_3")}),
732+
/*bound_out=*/nullptr, *exciting_schema);
733+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_1"),
734+
field_ref("dec128_21_3")}),
735+
call("case_when",
736+
{field_ref("a"), cast(field_ref("dec128_20_1"), decimal128(22, 3)),
737+
cast(field_ref("dec128_21_3"), decimal128(22, 3))}),
738+
/*bound_out=*/nullptr, *exciting_schema);
739+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
740+
field_ref("dec128_21_1")}),
741+
call("case_when",
742+
{field_ref("a"), cast(field_ref("dec128_20_3"), decimal128(23, 3)),
743+
cast(field_ref("dec128_21_1"), decimal128(23, 3))}),
744+
/*bound_out=*/nullptr, *exciting_schema);
745+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
746+
field_ref("dec256_21_3")}),
747+
call("case_when",
748+
{field_ref("a"), cast(field_ref("dec128_20_3"), decimal256(21, 3)),
749+
field_ref("dec256_21_3")}),
750+
/*bound_out=*/nullptr, *exciting_schema);
751+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_1"),
752+
field_ref("dec128_21_3")}),
753+
call("case_when",
754+
{field_ref("a"), cast(field_ref("dec256_20_1"), decimal256(22, 3)),
755+
cast(field_ref("dec128_21_3"), decimal256(22, 3))}),
756+
/*bound_out=*/nullptr, *exciting_schema);
757+
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_3"),
758+
field_ref("dec256_21_1")}),
759+
call("case_when",
760+
{field_ref("a"), cast(field_ref("dec256_20_3"), decimal256(23, 3)),
761+
cast(field_ref("dec256_21_1"), decimal256(23, 3))}),
762+
/*bound_out=*/nullptr, *exciting_schema);
763+
}
764+
720765
TEST(Expression, BindNestedCall) {
721766
auto expr = add(field_ref("a"),
722767
call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}),

cpp/src/arrow/compute/kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ std::string OutputType::ToString() const {
478478
// ----------------------------------------------------------------------
479479
// MatchConstraint
480480

481-
std::shared_ptr<MatchConstraint> MakeConstraint(
481+
std::shared_ptr<MatchConstraint> MatchConstraint::Make(
482482
std::function<bool(const std::vector<TypeHolder>&)> matches) {
483483
class FunctionMatchConstraint : public MatchConstraint {
484484
public:

cpp/src/arrow/compute/kernel.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,11 +356,11 @@ class ARROW_EXPORT MatchConstraint {
356356

357357
/// \brief Return true if the input types satisfy the constraint.
358358
virtual bool Matches(const std::vector<TypeHolder>& types) const = 0;
359-
};
360359

361-
/// \brief Convenience function to create a MatchConstraint from a match function.
362-
ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
363-
std::function<bool(const std::vector<TypeHolder>&)> matches);
360+
/// \brief Convenience function to create a MatchConstraint from a match function.
361+
static std::shared_ptr<MatchConstraint> Make(
362+
std::function<bool(const std::vector<TypeHolder>&)> matches);
363+
};
364364

365365
/// \brief Constraint that all input types are decimal types and have the same scale.
366366
ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();

cpp/src/arrow/compute/kernel_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,15 +313,15 @@ TEST(OutputType, Resolve) {
313313
TEST(MatchConstraint, ConvenienceMaker) {
314314
{
315315
auto always_match =
316-
MakeConstraint([](const std::vector<TypeHolder>& types) { return true; });
316+
MatchConstraint::Make([](const std::vector<TypeHolder>& types) { return true; });
317317

318318
ASSERT_TRUE(always_match->Matches({}));
319319
ASSERT_TRUE(always_match->Matches({int8(), int16(), int32()}));
320320
}
321321

322322
{
323323
auto always_false =
324-
MakeConstraint([](const std::vector<TypeHolder>& types) { return false; });
324+
MatchConstraint::Make([](const std::vector<TypeHolder>& types) { return false; });
325325

326326
ASSERT_FALSE(always_false->Matches({}));
327327
ASSERT_FALSE(always_false->Matches({int8(), int16(), int32()}));

cpp/src/arrow/compute/kernels/scalar_if_else.cc

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,6 +1451,23 @@ struct CaseWhenFunction : ScalarFunction {
14511451
if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
14521452
return arrow::compute::detail::NoMatchingKernel(this, *types);
14531453
}
1454+
1455+
static std::shared_ptr<MatchConstraint> DecimalMatchConstraint() {
1456+
static auto constraint =
1457+
MatchConstraint::Make([](const std::vector<TypeHolder>& types) -> bool {
1458+
DCHECK_GE(types.size(), 3);
1459+
DCHECK(std::all_of(types.begin() + 1, types.end(), [](const TypeHolder& type) {
1460+
return is_decimal(type.id());
1461+
}));
1462+
const auto& ty1 = checked_cast<const DecimalType&>(*types[1].type);
1463+
return std::all_of(
1464+
types.begin() + 2, types.end(), [&ty1](const TypeHolder& type) {
1465+
const auto& ty = checked_cast<const DecimalType&>(*type.type);
1466+
return ty1.Equals(ty);
1467+
});
1468+
});
1469+
return constraint;
1470+
}
14541471
};
14551472

14561473
// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions
@@ -2712,10 +2729,11 @@ struct ChooseFunction : ScalarFunction {
27122729
};
27132730

27142731
void AddCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function,
2715-
detail::GetTypeId get_id, ArrayKernelExec exec) {
2732+
detail::GetTypeId get_id, ArrayKernelExec exec,
2733+
std::shared_ptr<MatchConstraint> constraint = nullptr) {
27162734
ScalarKernel kernel(
27172735
KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, LastType,
2718-
/*is_varargs=*/true),
2736+
/*is_varargs=*/true, std::move(constraint)),
27192737
exec);
27202738
if (is_fixed_width(get_id.id)) {
27212739
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
@@ -2890,8 +2908,10 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
28902908
AddPrimitiveCaseWhenKernels(func, {boolean(), null(), float16()});
28912909
AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
28922910
CaseWhenFunctor<FixedSizeBinaryType>::Exec);
2893-
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
2894-
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
2911+
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec,
2912+
CaseWhenFunction::DecimalMatchConstraint());
2913+
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec,
2914+
CaseWhenFunction::DecimalMatchConstraint());
28952915
AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
28962916
AddNestedCaseWhenKernels(func);
28972917
DCHECK_OK(registry->AddFunction(std::move(func)));

cpp/src/arrow/compute/kernels/scalar_if_else_test.cc

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,6 +1807,74 @@ TEST(TestCaseWhen, Decimal) {
18071807
}
18081808
}
18091809

1810+
TEST(TestCaseWhen, DecimalPromotion) {
1811+
auto check_case_when_decimal_promotion =
1812+
[](std::shared_ptr<Scalar> body_true, std::shared_ptr<Scalar> body_false,
1813+
std::shared_ptr<Scalar> promoted_true, std::shared_ptr<Scalar> promoted_false) {
1814+
auto cond_true = ScalarFromJSON(boolean(), "true");
1815+
auto cond_false = ScalarFromJSON(boolean(), "false");
1816+
CheckScalar("case_when", {MakeStruct({cond_true}), body_true, body_false},
1817+
promoted_true);
1818+
CheckScalar("case_when", {MakeStruct({cond_false}), body_true, body_false},
1819+
promoted_false);
1820+
};
1821+
1822+
const std::vector<std::pair<int, int>> precisions = {{10, 20}, {15, 15}, {20, 10}};
1823+
const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
1824+
for (auto p : precisions) {
1825+
for (auto s : scales) {
1826+
auto p1 = p.first;
1827+
auto s1 = s.first;
1828+
auto p2 = p.second;
1829+
auto s2 = s.second;
1830+
1831+
auto max_scale = std::max({s1, s2});
1832+
auto scale_up_1 = max_scale - s1;
1833+
auto scale_up_2 = max_scale - s2;
1834+
auto max_precision = std::max({p1 + scale_up_1, p2 + scale_up_2});
1835+
1836+
// Operand string: 444.777...
1837+
std::string str_d1 =
1838+
R"(")" + std::string(p1 - s1, '4') + "." + std::string(s1, '7') + R"(")";
1839+
std::string str_d2 =
1840+
R"(")" + std::string(p2 - s2, '4') + "." + std::string(s2, '7') + R"(")";
1841+
1842+
// Promoted string: 444.777...000
1843+
std::string str_d1_promoted = R"(")" + std::string(p1 - s1, '4') + "." +
1844+
std::string(s1, '7') +
1845+
std::string(max_scale - s1, '0') + R"(")";
1846+
std::string str_d2_promoted = R"(")" + std::string(p2 - s2, '4') + "." +
1847+
std::string(s2, '7') +
1848+
std::string(max_scale - s2, '0') + R"(")";
1849+
1850+
auto d128_1 = decimal128(p1, s1);
1851+
auto d128_2 = decimal128(p2, s2);
1852+
auto d256_1 = decimal256(p1, s1);
1853+
auto d256_2 = decimal256(p2, s2);
1854+
auto d128_promoted = decimal128(max_precision, max_scale);
1855+
auto d256_promoted = decimal256(max_precision, max_scale);
1856+
1857+
auto scalar128_1 = ScalarFromJSON(d128_1, str_d1);
1858+
auto scalar128_2 = ScalarFromJSON(d128_2, str_d2);
1859+
auto scalar256_1 = ScalarFromJSON(d256_1, str_d1);
1860+
auto scalar256_2 = ScalarFromJSON(d256_2, str_d2);
1861+
auto scalar128_d1_promoted = ScalarFromJSON(d128_promoted, str_d1_promoted);
1862+
auto scalar128_d2_promoted = ScalarFromJSON(d128_promoted, str_d2_promoted);
1863+
auto scalar256_d1_promoted = ScalarFromJSON(d256_promoted, str_d1_promoted);
1864+
auto scalar256_d2_promoted = ScalarFromJSON(d256_promoted, str_d2_promoted);
1865+
1866+
check_case_when_decimal_promotion(scalar128_1, scalar128_2, scalar128_d1_promoted,
1867+
scalar128_d2_promoted);
1868+
check_case_when_decimal_promotion(scalar128_1, scalar256_2, scalar256_d1_promoted,
1869+
scalar256_d2_promoted);
1870+
check_case_when_decimal_promotion(scalar256_1, scalar128_2, scalar256_d1_promoted,
1871+
scalar256_d2_promoted);
1872+
check_case_when_decimal_promotion(scalar256_1, scalar256_2, scalar256_d1_promoted,
1873+
scalar256_d2_promoted);
1874+
}
1875+
}
1876+
}
1877+
18101878
TEST(TestCaseWhen, FixedSizeBinary) {
18111879
auto type = fixed_size_binary(3);
18121880
auto cond_true = ScalarFromJSON(boolean(), "true");
@@ -2509,6 +2577,28 @@ TEST(TestCaseWhen, UnionBoolStringRandom) {
25092577
}
25102578
}
25112579

2580+
TEST(TestCaseWhen, DispatchExact) {
2581+
// Decimal types with same (p, s)
2582+
CheckDispatchExact("case_when", {struct_({field("", boolean())}), decimal128(20, 3),
2583+
decimal128(20, 3)});
2584+
CheckDispatchExact("case_when", {struct_({field("", boolean())}), decimal256(20, 3),
2585+
decimal256(20, 3)});
2586+
2587+
// Decimal types with different (p, s)
2588+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2589+
decimal128(20, 3), decimal128(21, 3)});
2590+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2591+
decimal128(20, 1), decimal128(20, 3)});
2592+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2593+
decimal128(20, 3), decimal256(20, 3)});
2594+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2595+
decimal256(20, 3), decimal128(21, 3)});
2596+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2597+
decimal256(20, 3), decimal256(21, 3)});
2598+
CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
2599+
decimal256(20, 1), decimal256(20, 3)});
2600+
}
2601+
25122602
TEST(TestCaseWhen, DispatchBest) {
25132603
CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()},
25142604
{struct_({field("", boolean())}), int64(), int64()});
@@ -2559,6 +2649,32 @@ TEST(TestCaseWhen, DispatchBest) {
25592649
CheckDispatchBest(
25602650
"case_when", {struct_({field("", boolean())}), dictionary(int64(), utf8()), utf8()},
25612651
{struct_({field("", boolean())}), utf8(), utf8()});
2652+
2653+
// Decimal promotion
2654+
CheckDispatchBest(
2655+
"case_when",
2656+
{struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 3)},
2657+
{struct_({field("", boolean())}), decimal128(21, 3), decimal128(21, 3)});
2658+
CheckDispatchBest(
2659+
"case_when",
2660+
{struct_({field("", boolean())}), decimal128(20, 1), decimal128(21, 3)},
2661+
{struct_({field("", boolean())}), decimal128(22, 3), decimal128(22, 3)});
2662+
CheckDispatchBest(
2663+
"case_when",
2664+
{struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 1)},
2665+
{struct_({field("", boolean())}), decimal128(23, 3), decimal128(23, 3)});
2666+
CheckDispatchBest(
2667+
"case_when",
2668+
{struct_({field("", boolean())}), decimal128(20, 3), decimal256(21, 3)},
2669+
{struct_({field("", boolean())}), decimal256(21, 3), decimal256(21, 3)});
2670+
CheckDispatchBest(
2671+
"case_when",
2672+
{struct_({field("", boolean())}), decimal256(20, 1), decimal128(21, 3)},
2673+
{struct_({field("", boolean())}), decimal256(22, 3), decimal256(22, 3)});
2674+
CheckDispatchBest(
2675+
"case_when",
2676+
{struct_({field("", boolean())}), decimal256(20, 3), decimal256(21, 1)},
2677+
{struct_({field("", boolean())}), decimal256(23, 3), decimal256(23, 3)});
25622678
}
25632679

25642680
template <typename Type>

0 commit comments

Comments
 (0)