-
Notifications
You must be signed in to change notification settings - Fork 4.1k
GH-41336: [C++][Compute] Fix case_when kernel dispatch for decimals with different precisions and scales #47479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -356,11 +356,11 @@ class ARROW_EXPORT MatchConstraint { | |
|
|
||
| /// \brief Return true if the input types satisfy the constraint. | ||
| virtual bool Matches(const std::vector<TypeHolder>& types) const = 0; | ||
| }; | ||
|
|
||
| /// \brief Convenience function to create a MatchConstraint from a match function. | ||
| ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint( | ||
| std::function<bool(const std::vector<TypeHolder>&)> matches); | ||
| /// \brief Convenience function to create a MatchConstraint from a match function. | ||
| static std::shared_ptr<MatchConstraint> Make( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not really related to the issue. Just a small refinement. |
||
| std::function<bool(const std::vector<TypeHolder>&)> matches); | ||
| }; | ||
|
|
||
| /// \brief Constraint that all input types are decimal types and have the same scale. | ||
| ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1451,6 +1451,23 @@ struct CaseWhenFunction : ScalarFunction { | |
| if (auto kernel = DispatchExactImpl(this, *types)) return kernel; | ||
| return arrow::compute::detail::NoMatchingKernel(this, *types); | ||
| } | ||
|
|
||
| static std::shared_ptr<MatchConstraint> DecimalMatchConstraint() { | ||
| static auto constraint = | ||
| MatchConstraint::Make([](const std::vector<TypeHolder>& types) -> bool { | ||
| DCHECK_GE(types.size(), 3); | ||
| DCHECK(std::all_of(types.begin() + 1, types.end(), [](const TypeHolder& type) { | ||
| return is_decimal(type.id()); | ||
| })); | ||
| const auto& ty1 = checked_cast<const DecimalType&>(*types[1].type); | ||
|
||
| return std::all_of( | ||
| types.begin() + 2, types.end(), [&ty1](const TypeHolder& type) { | ||
| const auto& ty = checked_cast<const DecimalType&>(*type.type); | ||
| return ty1.Equals(ty); | ||
| }); | ||
| }); | ||
| return constraint; | ||
| } | ||
| }; | ||
|
|
||
| // Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions | ||
|
|
@@ -2712,10 +2729,11 @@ struct ChooseFunction : ScalarFunction { | |
| }; | ||
|
|
||
| void AddCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function, | ||
| detail::GetTypeId get_id, ArrayKernelExec exec) { | ||
| detail::GetTypeId get_id, ArrayKernelExec exec, | ||
| std::shared_ptr<MatchConstraint> constraint = nullptr) { | ||
| ScalarKernel kernel( | ||
| KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, LastType, | ||
| /*is_varargs=*/true), | ||
| /*is_varargs=*/true, std::move(constraint)), | ||
| exec); | ||
| if (is_fixed_width(get_id.id)) { | ||
| kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; | ||
|
|
@@ -2890,8 +2908,10 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { | |
| AddPrimitiveCaseWhenKernels(func, {boolean(), null(), float16()}); | ||
| AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY, | ||
| CaseWhenFunctor<FixedSizeBinaryType>::Exec); | ||
| AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec); | ||
| AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec); | ||
| AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec, | ||
| CaseWhenFunction::DecimalMatchConstraint()); | ||
| AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec, | ||
| CaseWhenFunction::DecimalMatchConstraint()); | ||
| AddBinaryCaseWhenKernels(func, BaseBinaryTypes()); | ||
| AddNestedCaseWhenKernels(func); | ||
| DCHECK_OK(registry->AddFunction(std::move(func))); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm excited 🤩