Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,51 @@ TEST(Expression, BindWithImplicitCasts) {
call("is_in", {cast(field_ref("dict_str"), utf8())}, in_a));
}

// Checks that binding a "case_when" call whose two value arguments are decimals of
// differing precision, scale and/or bit width yields the expected bound expression,
// i.e. with explicit casts of the value arguments to one common decimal type.
TEST(Expression, BindWithImplicitCastsForCaseWhenOnDecimal) {
auto exciting_schema = schema(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm excited 🤩

{field("a", struct_({field("", boolean())})),
field("dec128_20_3", decimal128(20, 3)), field("dec128_21_3", decimal128(21, 3)),
field("dec128_20_1", decimal128(20, 1)), field("dec128_21_1", decimal128(21, 1)),
field("dec256_20_3", decimal256(20, 3)), field("dec256_21_3", decimal256(21, 3)),
field("dec256_20_1", decimal256(20, 1)), field("dec256_21_1", decimal256(21, 1))});
// decimal128, same scale, different precision: only the narrower operand is cast,
// to decimal128(21, 3); the other operand is expected unchanged.
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
field_ref("dec128_21_3")}),
call("case_when",
{field_ref("a"), cast(field_ref("dec128_20_3"), decimal128(21, 3)),
field_ref("dec128_21_3")}),
/*bound_out=*/nullptr, *exciting_schema);
// decimal128, scales 1 and 3: both operands are expected to be cast to
// decimal128(22, 3).
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_1"),
field_ref("dec128_21_3")}),
call("case_when",
{field_ref("a"), cast(field_ref("dec128_20_1"), decimal128(22, 3)),
cast(field_ref("dec128_21_3"), decimal128(22, 3))}),
/*bound_out=*/nullptr, *exciting_schema);
// decimal128, scales 3 and 1: both operands are expected to be cast to
// decimal128(23, 3).
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
field_ref("dec128_21_1")}),
call("case_when",
{field_ref("a"), cast(field_ref("dec128_20_3"), decimal128(23, 3)),
cast(field_ref("dec128_21_1"), decimal128(23, 3))}),
/*bound_out=*/nullptr, *exciting_schema);
// Mixed widths, same scale: the decimal128 operand is cast to decimal256(21, 3);
// the decimal256 operand is expected unchanged.
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
field_ref("dec256_21_3")}),
call("case_when",
{field_ref("a"), cast(field_ref("dec128_20_3"), decimal128(21, 3)),
field_ref("dec256_21_3")}),
/*bound_out=*/nullptr, *exciting_schema);
// Mixed widths, scales 1 and 3: both operands are expected to be cast to
// decimal256(22, 3).
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_1"),
field_ref("dec128_21_3")}),
call("case_when",
{field_ref("a"), cast(field_ref("dec256_20_1"), decimal256(22, 3)),
cast(field_ref("dec128_21_3"), decimal256(22, 3))}),
/*bound_out=*/nullptr, *exciting_schema);
// decimal256, scales 3 and 1: both operands are expected to be cast to
// decimal256(23, 3).
ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_3"),
field_ref("dec256_21_1")}),
call("case_when",
{field_ref("a"), cast(field_ref("dec256_20_3"), decimal256(23, 3)),
cast(field_ref("dec256_21_1"), decimal256(23, 3))}),
/*bound_out=*/nullptr, *exciting_schema);
}

TEST(Expression, BindNestedCall) {
auto expr = add(field_ref("a"),
call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}),
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ std::string OutputType::ToString() const {
// ----------------------------------------------------------------------
// MatchConstraint

std::shared_ptr<MatchConstraint> MakeConstraint(
std::shared_ptr<MatchConstraint> MatchConstraint::Make(
std::function<bool(const std::vector<TypeHolder>&)> matches) {
class FunctionMatchConstraint : public MatchConstraint {
public:
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,11 @@ class ARROW_EXPORT MatchConstraint {

/// \brief Return true if the input types satisfy the constraint.
virtual bool Matches(const std::vector<TypeHolder>& types) const = 0;
};

/// \brief Convenience function to create a MatchConstraint from a match function.
ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
std::function<bool(const std::vector<TypeHolder>&)> matches);
/// \brief Convenience function to create a MatchConstraint from a match function.
static std::shared_ptr<MatchConstraint> Make(
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really related to the issue. Just a small refinement.

std::function<bool(const std::vector<TypeHolder>&)> matches);
};

/// \brief Constraint that all input types are decimal types and have the same scale.
ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,15 +313,15 @@ TEST(OutputType, Resolve) {
TEST(MatchConstraint, ConvenienceMaker) {
{
auto always_match =
MakeConstraint([](const std::vector<TypeHolder>& types) { return true; });
MatchConstraint::Make([](const std::vector<TypeHolder>& types) { return true; });

ASSERT_TRUE(always_match->Matches({}));
ASSERT_TRUE(always_match->Matches({int8(), int16(), int32()}));
}

{
auto always_false =
MakeConstraint([](const std::vector<TypeHolder>& types) { return false; });
MatchConstraint::Make([](const std::vector<TypeHolder>& types) { return false; });

ASSERT_FALSE(always_false->Matches({}));
ASSERT_FALSE(always_false->Matches({int8(), int16(), int32()}));
Expand Down
25 changes: 21 additions & 4 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,20 @@ struct CaseWhenFunction : ScalarFunction {
if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
return arrow::compute::detail::NoMatchingKernel(this, *types);
}

static std::shared_ptr<MatchConstraint> DecimalMatchConstraint() {
static auto constraint =
MatchConstraint::Make([](const std::vector<TypeHolder>& types) -> bool {
DCHECK_GE(types.size(), 2);
DCHECK(std::all_of(types.begin() + 1, types.end(), [](const TypeHolder& type) {
return is_decimal(type.id());
}));
return std::all_of(
types.begin() + 2, types.end(),
[&types](const TypeHolder& type) { return type == types[1]; });
});
return constraint;
}
};

// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions
Expand Down Expand Up @@ -2712,10 +2726,11 @@ struct ChooseFunction : ScalarFunction {
};

void AddCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
detail::GetTypeId get_id, ArrayKernelExec exec,
std::shared_ptr<MatchConstraint> constraint = nullptr) {
ScalarKernel kernel(
KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, LastType,
/*is_varargs=*/true),
/*is_varargs=*/true, std::move(constraint)),
exec);
if (is_fixed_width(get_id.id)) {
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
Expand Down Expand Up @@ -2890,8 +2905,10 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveCaseWhenKernels(func, {boolean(), null(), float16()});
AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec,
CaseWhenFunction::DecimalMatchConstraint());
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec,
CaseWhenFunction::DecimalMatchConstraint());
AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
AddNestedCaseWhenKernels(func);
DCHECK_OK(registry->AddFunction(std::move(func)));
Expand Down
116 changes: 116 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1807,6 +1807,74 @@ TEST(TestCaseWhen, Decimal) {
}
}

// Exercises "case_when" with two decimal value arguments over every combination of
// the precision pairs, scale pairs and decimal128/decimal256 widths below, and
// expects the selected value to come back in the promoted decimal type.
TEST(TestCaseWhen, DecimalPromotion) {
  // Evaluates "case_when" once with a true and once with a false condition and
  // compares the result against the corresponding expected scalar.
  auto expect_promoted = [](std::shared_ptr<Scalar> when_true,
                            std::shared_ptr<Scalar> when_false,
                            std::shared_ptr<Scalar> expected_true,
                            std::shared_ptr<Scalar> expected_false) {
    auto true_cond = ScalarFromJSON(boolean(), "true");
    auto false_cond = ScalarFromJSON(boolean(), "false");
    CheckScalar("case_when", {MakeStruct({true_cond}), when_true, when_false},
                expected_true);
    CheckScalar("case_when", {MakeStruct({false_cond}), when_true, when_false},
                expected_false);
  };

  // Builds a quoted JSON decimal literal: (precision - scale) '4' digits, a dot,
  // `scale` '7' digits and finally `zero_padding` trailing '0' digits.
  auto decimal_json = [](int precision, int scale, int zero_padding) {
    return R"(")" + std::string(precision - scale, '4') + "." +
           std::string(scale, '7') + std::string(zero_padding, '0') + R"(")";
  };

  const std::vector<std::pair<int, int>> precisions = {{10, 20}, {15, 15}, {20, 10}};
  const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
  for (const auto& precision_pair : precisions) {
    for (const auto& scale_pair : scales) {
      const int p1 = precision_pair.first;
      const int p2 = precision_pair.second;
      const int s1 = scale_pair.first;
      const int s2 = scale_pair.second;

      // The promoted type uses the larger scale; each operand's precision grows by
      // the number of fractional digits it gains, and the larger result wins.
      const int promoted_scale = std::max(s1, s2);
      const int extra_digits_1 = promoted_scale - s1;
      const int extra_digits_2 = promoted_scale - s2;
      const int promoted_precision = std::max(p1 + extra_digits_1, p2 + extra_digits_2);

      // Operand literals: 444.777...
      const std::string json_1 = decimal_json(p1, s1, /*zero_padding=*/0);
      const std::string json_2 = decimal_json(p2, s2, /*zero_padding=*/0);

      // Promoted literals: 444.777...000
      const std::string json_1_promoted = decimal_json(p1, s1, extra_digits_1);
      const std::string json_2_promoted = decimal_json(p2, s2, extra_digits_2);

      auto type128_promoted = decimal128(promoted_precision, promoted_scale);
      auto type256_promoted = decimal256(promoted_precision, promoted_scale);

      auto operand128_1 = ScalarFromJSON(decimal128(p1, s1), json_1);
      auto operand128_2 = ScalarFromJSON(decimal128(p2, s2), json_2);
      auto operand256_1 = ScalarFromJSON(decimal256(p1, s1), json_1);
      auto operand256_2 = ScalarFromJSON(decimal256(p2, s2), json_2);
      auto expected128_1 = ScalarFromJSON(type128_promoted, json_1_promoted);
      auto expected128_2 = ScalarFromJSON(type128_promoted, json_2_promoted);
      auto expected256_1 = ScalarFromJSON(type256_promoted, json_1_promoted);
      auto expected256_2 = ScalarFromJSON(type256_promoted, json_2_promoted);

      // decimal128 with decimal128 is expected as decimal128; any combination that
      // involves decimal256 is expected as decimal256.
      expect_promoted(operand128_1, operand128_2, expected128_1, expected128_2);
      expect_promoted(operand128_1, operand256_2, expected256_1, expected256_2);
      expect_promoted(operand256_1, operand128_2, expected256_1, expected256_2);
      expect_promoted(operand256_1, operand256_2, expected256_1, expected256_2);
    }
  }
}

TEST(TestCaseWhen, FixedSizeBinary) {
auto type = fixed_size_binary(3);
auto cond_true = ScalarFromJSON(boolean(), "true");
Expand Down Expand Up @@ -2509,6 +2577,28 @@ TEST(TestCaseWhen, UnionBoolStringRandom) {
}
}

// Exact dispatch of "case_when" on decimal value arguments: expected to succeed when
// both value types are the same decimal type and to fail when they differ in
// precision, scale or bit width.
TEST(TestCaseWhen, DispatchExact) {
  const auto cond_type = struct_({field("", boolean())});

  // Decimal types with same (p, s)
  const std::vector<std::shared_ptr<DataType>> matching_types = {decimal128(20, 3),
                                                                 decimal256(20, 3)};
  for (const auto& value_type : matching_types) {
    CheckDispatchExact("case_when", {cond_type, value_type, value_type});
  }

  // Decimal types with different (p, s)
  const std::vector<std::pair<std::shared_ptr<DataType>, std::shared_ptr<DataType>>>
      mismatched_types = {
          {decimal128(20, 3), decimal128(21, 3)}, {decimal128(20, 1), decimal128(20, 3)},
          {decimal128(20, 3), decimal256(20, 3)}, {decimal256(20, 3), decimal128(21, 3)},
          {decimal256(20, 3), decimal256(21, 3)}, {decimal256(20, 1), decimal256(20, 3)},
      };
  for (const auto& value_types : mismatched_types) {
    CheckDispatchExactFails("case_when",
                            {cond_type, value_types.first, value_types.second});
  }
}

TEST(TestCaseWhen, DispatchBest) {
CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()},
{struct_({field("", boolean())}), int64(), int64()});
Expand Down Expand Up @@ -2559,6 +2649,32 @@ TEST(TestCaseWhen, DispatchBest) {
CheckDispatchBest(
"case_when", {struct_({field("", boolean())}), dictionary(int64(), utf8()), utf8()},
{struct_({field("", boolean())}), utf8(), utf8()});

// Decimal promotion
CheckDispatchBest(
"case_when",
{struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 3)},
{struct_({field("", boolean())}), decimal128(21, 3), decimal128(21, 3)});
CheckDispatchBest(
"case_when",
{struct_({field("", boolean())}), decimal128(20, 1), decimal128(21, 3)},
{struct_({field("", boolean())}), decimal128(22, 3), decimal128(22, 3)});
CheckDispatchBest(
"case_when",
{struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 1)},
{struct_({field("", boolean())}), decimal128(23, 3), decimal128(23, 3)});
CheckDispatchBest(
"case_when",
{struct_({field("", boolean())}), decimal128(20, 3), decimal256(21, 3)},
{struct_({field("", boolean())}), decimal256(21, 3), decimal256(21, 3)});
CheckDispatchBest(
"case_when",
{struct_({field("", boolean())}), decimal256(20, 1), decimal128(21, 3)},
{struct_({field("", boolean())}), decimal256(22, 3), decimal256(22, 3)});
CheckDispatchBest(
"case_when",
{struct_({field("", boolean())}), decimal256(20, 3), decimal256(21, 1)},
{struct_({field("", boolean())}), decimal256(23, 3), decimal256(23, 3)});
}

template <typename Type>
Expand Down
Loading