-
Notifications
You must be signed in to change notification settings - Fork 4.1k
GH-41336: [C++][Compute] Fix the bug of decimal types skipping cast in IfElse related expression function calls #41363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1195,6 +1195,43 @@ struct ResolveIfElseExec<NullType, AllocateMem> { | |
| } | ||
| }; | ||
|
|
||
| template <typename ResolverForOtherTypes> | ||
| Result<TypeHolder> ResolveDecimalCaseType(KernelContext* ctx, | ||
| const std::vector<TypeHolder>& types, | ||
| ResolverForOtherTypes&& resolver) { | ||
| if (!HasDecimal(types)) { | ||
| return resolver(ctx, types); | ||
| } | ||
|
|
||
| int32_t max_precision = 0, max_scale = 0; | ||
| for (auto& type : types) { | ||
| if (is_floating(type.id()) || is_integer(type.id())) { | ||
| return Status::Invalid("Need to cast numeric types containing decimal types"); | ||
|
Comment on lines
+1208
to
+1209
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think a better approach for these failures is to have kernels that fail. You write input matchers to cover all cases and dispatch to the fail kernels for types that can't be handled without casts.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the approach with |
||
| } else if (is_decimal(type.id())) { | ||
| const auto& decimal_type = checked_cast<const arrow::DecimalType&>(*type); | ||
| if (decimal_type.precision() < max_precision || decimal_type.scale() < max_scale) { | ||
|
||
| return Status::Invalid("Need to cast decimal types"); | ||
| } | ||
| max_precision = std::max(max_precision, decimal_type.precision()); | ||
| max_scale = std::max(max_scale, decimal_type.scale()); | ||
| } else { | ||
| // Do nothing, needn't cast | ||
| } | ||
| } | ||
|
|
||
| return resolver(ctx, types); | ||
| } | ||
|
|
||
| Result<TypeHolder> ResolveCoalesceOutputType(KernelContext* ctx, | ||
| const std::vector<TypeHolder>& types) { | ||
| return ResolveDecimalCaseType(ctx, types, FirstType); | ||
| } | ||
|
|
||
| Result<TypeHolder> ResolveOutputType(KernelContext* ctx, | ||
| const std::vector<TypeHolder>& types) { | ||
| return ResolveDecimalCaseType(ctx, types, LastType); | ||
| } | ||
|
|
||
| struct IfElseFunction : ScalarFunction { | ||
| using ScalarFunction::ScalarFunction; | ||
|
|
||
|
|
@@ -1299,7 +1336,8 @@ void AddBinaryIfElseKernels(const std::shared_ptr<IfElseFunction>& scalar_functi | |
| template <typename T> | ||
| void AddFixedWidthIfElseKernel(const std::shared_ptr<IfElseFunction>& scalar_function) { | ||
| auto type_id = T::type_id; | ||
| ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)}, LastType, | ||
| ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)}, | ||
| ResolveOutputType, | ||
| ResolveIfElseExec<T, /*AllocateMem=*/std::false_type>::Exec); | ||
| kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; | ||
| kernel.mem_allocation = MemAllocation::PREALLOCATE; | ||
|
|
@@ -2681,7 +2719,8 @@ struct ChooseFunction : ScalarFunction { | |
| void AddCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function, | ||
| detail::GetTypeId get_id, ArrayKernelExec exec) { | ||
| ScalarKernel kernel( | ||
| KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, LastType, | ||
| KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, | ||
| ResolveOutputType, | ||
| /*is_varargs=*/true), | ||
| exec); | ||
| if (is_fixed_width(get_id.id)) { | ||
|
|
@@ -2714,9 +2753,10 @@ void AddBinaryCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_fu | |
|
|
||
| void AddCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function, | ||
| detail::GetTypeId get_id, ArrayKernelExec exec) { | ||
| ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, FirstType, | ||
| /*is_varargs=*/true), | ||
| exec); | ||
| ScalarKernel kernel( | ||
| KernelSignature::Make({InputType(get_id.id)}, ResolveCoalesceOutputType, | ||
| /*is_varargs=*/true), | ||
| exec); | ||
| kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; | ||
| kernel.mem_allocation = MemAllocation::PREALLOCATE; | ||
| kernel.can_write_into_slices = is_fixed_width(get_id.id); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Name suggestion:
ResolveDecimalCaseWhenOutputTypeBecause "DecimalCase" is ambiguous by "case" being a word that would fit here, but it actually refers to the kernel function called "CaseWhen".