Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,112 @@ TEST(Expression, ExecuteCall) {
])"));
}

TEST(Expression, ExecuteCallWithDecimalIfElseOps) {
  // GH-41336: casts failed in the 'IfElse'-related kernel functions. Make sure
  // that executing the expression call gives the same results as the
  // corresponding 'CallFunction'.
  //
  // The raw-string JSON payloads below are intentionally left unindented so the
  // literals stay byte-identical.
  //
  // if_else
  {
    // decimal128(3,2), decimal128(4,1) --> output_type : decimal128(5,2)
    ExpectExecute(
        call("if_else", {field_ref("b1"), field_ref("a1"), field_ref("a2")}),
        ArrayFromJSON(struct_({field("b1", boolean()), field("a1", decimal128(3, 2)),
                               field("a2", decimal128(4, 1))}),
                      R"([
{"b1": true, "a1": "1.23", "a2": "121.3"},
{"b1": true, "a1": "2.34", "a2": "-232.3"},
{"b1": false, "a1": "-1.23", "a2": "0.0"}
])"));

    // decimal256(3,2), decimal128(3,2) --> output_type : decimal256(3,2)
    ExpectExecute(
        call("if_else", {field_ref("b1"), field_ref("a1"), field_ref("a2")}),
        ArrayFromJSON(struct_({field("b1", boolean()), field("a1", decimal256(3, 2)),
                               field("a2", decimal128(3, 2))}),
                      R"([
{"b1": true, "a1": "1.23", "a2": "1.34"},
{"b1": true, "a1": "2.34", "a2": "-2.34"},
{"b1": false, "a1": "-1.23", "a2": "0.00"}
])"));

    // decimal256(3,2), decimal128(4,1) --> output_type : decimal256(5,2)
    // "a1" must be decimal256 here: with decimal128(3,2) this case would be an
    // exact duplicate of the first one and the mixed-width path with differing
    // scales would go untested.
    ExpectExecute(
        call("if_else", {field_ref("b1"), field_ref("a1"), field_ref("a2")}),
        ArrayFromJSON(struct_({field("b1", boolean()), field("a1", decimal256(3, 2)),
                               field("a2", decimal128(4, 1))}),
                      R"([
{"b1": true, "a1": "1.23", "a2": "121.3"},
{"b1": true, "a1": "2.34", "a2": "-232.3"},
{"b1": false, "a1": "-1.23", "a2": "0.0"}
])"));
  }

  // case_when
  {
    // decimal128(4,2), decimal128(3,2) --> output_type : decimal128(4,2)
    ExpectExecute(call("case_when", {field_ref("c"), field_ref("a1"), field_ref("a2")}),
                  ArrayFromJSON(struct_({field("c", struct_({field("m", boolean())})),
                                         field("a1", decimal128(4, 2)),
                                         field("a2", decimal128(3, 2))}),
                                R"([
{ "c": {"m": true}, "a1": "1.23", "a2": "1.34"},
{ "c": {"m": false}, "a1": "2.34", "a2": "-2.34"}
])"));

    // decimal128(4,2), decimal128(3,3), decimal256(2,1)
    // --> output_type: presumably decimal256(5,3) -- TODO confirm
    ExpectExecute(
        call("case_when",
             {field_ref("c"), field_ref("a1"), field_ref("a2"), field_ref("a3")}),
        ArrayFromJSON(
            struct_(
                {field("c", struct_({field("m1", boolean()), field("m2", boolean())})),
                 field("a1", decimal128(4, 2)), field("a2", decimal128(3, 3)),
                 field("a3", decimal256(2, 1))}),
            R"([
{ "c": {"m1": true, "m2": false}, "a1": "1.23", "a2": "1.342", "a3": "0.0"},
{ "c": {"m1": true, "m2": false}, "a1": "2.34", "a2": "-2.314", "a3": "3.1"},
{ "c": {"m1": null, "m2": true}, "a1": "2.34", "a2": null, "a3": "3.1"},
{ "c": {"m1": null, "m2": null}, "a1": "2.34", "a2": "-2.034", "a3": "3.1"}
])"));

    // int32(), float32(), decimal256(2,1) --> output_type: decimal256(2,1)
    // NOTE(review): the stated output type is unverified for a mix of integer,
    // floating-point and decimal arguments -- confirm against the cast rules.
    ExpectExecute(call("case_when", {field_ref("c"), field_ref("a1"), field_ref("a2"),
                                     field_ref("a3")}),
                  ArrayFromJSON(struct_({field("c", struct_({field("m1", boolean()),
                                                             field("m2", boolean())})),
                                         field("a1", int32()), field("a2", float32()),
                                         field("a3", decimal256(2, 1))}),
                                R"([
{ "c": {"m1": true, "m2": false}, "a1": 1, "a2": 1.342, "a3": "0.0"},
{ "c": {"m1": false, "m2": false}, "a1": 34, "a2": 2.314, "a3": "3.1"}
])"));
  }

  // coalesce
  {
    // decimal(4,1), decimal128(3,3), decimal128(2,1)
    // --> output_type: presumably decimal128(6,3) -- TODO confirm
    ExpectExecute(
        call("coalesce", {field_ref("a1"), field_ref("a2"), field_ref("a3")}),
        ArrayFromJSON(struct_({field("a1", decimal(4, 1)), field("a2", decimal128(3, 3)),
                               field("a3", decimal128(2, 1))}),
                      R"([
{"a1": null, "a2": "1.123", "a3": "2.3"},
{"a1": null, "a2": null, "a3": "-3.3"},
{"a1": "45.3", "a2": "-1.230", "a3": "0.0"}
])"));

    // decimal(4,1), int64(), decimal128(2,2)
    // NOTE(review): the expected output type still needs to be confirmed; it
    // must have a scale of at least 2 to represent "a3".
    ExpectExecute(call("coalesce", {field_ref("a1"), field_ref("a2"), field_ref("a3")}),
                  ArrayFromJSON(struct_({field("a1", decimal(4, 1)), field("a2", int64()),
                                         field("a3", decimal128(2, 2))}),
                                R"([
{"a1": null, "a2": 123, "a3": "2.31"},
{"a1": null, "a2": null, "a3": "-3.03"},
{"a1": "45.3", "a2": -30, "a3": "0.00"}
])"));
  }
}

TEST(Expression, ExecuteCallWithNoArguments) {
const int kCount = 10;
auto random_options = RandomOptions::FromSeed(/*seed=*/0);
Expand Down
50 changes: 45 additions & 5 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,43 @@ struct ResolveIfElseExec<NullType, AllocateMem> {
}
};

template <typename ResolverForOtherTypes>
Result<TypeHolder> ResolveDecimalCaseType(KernelContext* ctx,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name suggestion: ResolveDecimalCaseWhenOutputType

"DecimalCase" is ambiguous: "case" reads like an ordinary word that would fit here, but it actually refers to the kernel function called "CaseWhen".

const std::vector<TypeHolder>& types,
ResolverForOtherTypes&& resolver) {
if (!HasDecimal(types)) {
return resolver(ctx, types);
}

int32_t max_precision = 0, max_scale = 0;
for (auto& type : types) {
if (is_floating(type.id()) || is_integer(type.id())) {
return Status::Invalid("Need to cast numeric types containing decimal types");
Comment on lines +1208 to +1209
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a better approach for these failures is to have kernels that fail. You write input matchers to cover all cases and dispatch to the fail kernels for types that can't be handled without casts.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the TypeMatcher approach is more suitable. If there are no problems with PR #41012, which adds a new matcher interface for combinations of types, this PR will continue to move forward.

} else if (is_decimal(type.id())) {
const auto& decimal_type = checked_cast<const arrow::DecimalType&>(*type);
if (decimal_type.precision() < max_precision || decimal_type.scale() < max_scale) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

< 0?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The returned Invalid status is used by this logic:
https://github.com/bkietz/arrow/blob/c25866bc59e30e43e8dc3b05f3973d48074c3594/cpp/src/arrow/compute/expression.cc#L551-L552

It lets the expression call fall through to DispatchBest so that the arguments can be cast correctly.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In BindNonRecursive:
FinishBind resolves the output type. The first time we enter the resolver, the decimal arguments have not been cast yet, so we can check whether the decimal types have different precision and scale.

If they do, we return Invalid so that BindNonRecursive goes into the IfElse-related function's DispatchBest, which performs the casts.

After that, when the type resolver runs a second time in BindNonRecursive, the decimal types have already been cast to the same precision and scale, so the correct output type can be resolved.

return Status::Invalid("Need to cast decimal types");
}
max_precision = std::max(max_precision, decimal_type.precision());
max_scale = std::max(max_scale, decimal_type.scale());
} else {
// Do nothing, needn't cast
}
}

return resolver(ctx, types);
}

/// Output-type resolver used by the "coalesce" kernels.
///
/// Delegates to ResolveDecimalCaseType, handing it FirstType as the resolver
/// that ultimately produces the output type once the decimal checks pass (or
/// when no decimal argument is present).
///
/// \param ctx kernel context, forwarded unchanged
/// \param types argument types of the call
/// \return the resolved output type, or the error status reported by
///         ResolveDecimalCaseType / FirstType
Result<TypeHolder> ResolveCoalesceOutputType(KernelContext* ctx,
                                             const std::vector<TypeHolder>& types) {
  auto& fallback_resolver = FirstType;
  return ResolveDecimalCaseType(ctx, types, fallback_resolver);
}

/// Output-type resolver that defers to LastType after the decimal checks.
///
/// Delegates to ResolveDecimalCaseType, handing it LastType as the resolver
/// that ultimately produces the output type once the decimal checks pass (or
/// when no decimal argument is present).
///
/// \param ctx kernel context, forwarded unchanged
/// \param types argument types of the call
/// \return the resolved output type, or the error status reported by
///         ResolveDecimalCaseType / LastType
Result<TypeHolder> ResolveOutputType(KernelContext* ctx,
                                     const std::vector<TypeHolder>& types) {
  auto& fallback_resolver = LastType;
  return ResolveDecimalCaseType(ctx, types, fallback_resolver);
}

struct IfElseFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;

Expand Down Expand Up @@ -1299,7 +1336,8 @@ void AddBinaryIfElseKernels(const std::shared_ptr<IfElseFunction>& scalar_functi
template <typename T>
void AddFixedWidthIfElseKernel(const std::shared_ptr<IfElseFunction>& scalar_function) {
auto type_id = T::type_id;
ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)}, LastType,
ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)},
ResolveOutputType,
ResolveIfElseExec<T, /*AllocateMem=*/std::false_type>::Exec);
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
kernel.mem_allocation = MemAllocation::PREALLOCATE;
Expand Down Expand Up @@ -2681,7 +2719,8 @@ struct ChooseFunction : ScalarFunction {
void AddCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(
KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, LastType,
KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)},
ResolveOutputType,
/*is_varargs=*/true),
exec);
if (is_fixed_width(get_id.id)) {
Expand Down Expand Up @@ -2714,9 +2753,10 @@ void AddBinaryCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_fu

void AddCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, FirstType,
/*is_varargs=*/true),
exec);
ScalarKernel kernel(
KernelSignature::Make({InputType(get_id.id)}, ResolveCoalesceOutputType,
/*is_varargs=*/true),
exec);
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
kernel.mem_allocation = MemAllocation::PREALLOCATE;
kernel.can_write_into_slices = is_fixed_width(get_id.id);
Expand Down