Skip to content

Commit dadc21f

Browse files
authored
GH-47287: [C++][Compute] Add constraint for kernel signature matching and use it for binary decimal arithmetic kernels (#47297)
### Rationale for this change A rework of #40223 using a more systematic alternative. ### What changes are included in this PR? Introduce a structure `MatchConstraint` for applying extra (and optional) matching constraint for kernel signature matching, in additional to simply input type checks. Also implement two concrete `MatchConstraint`s for binary decimal arithmetic kernels, to suppress exact match even if the input types are OK, for example, by requiring all decimal must be of the same scale for `add` and `subtract`, and s1 >= s2 for `divide`. This should also be a fundamental enhancement to further resolve similar issues like: * #35843 * #39875 * #40911 * #41011 * #41336 (Haven't try each one of them. May do that if this PR gets merged.) ### Are these changes tested? UT included. ### Are there any user-facing changes? New public class `MatchConstraint`. * GitHub Issue: #47287 Authored-by: Rossi Sun <[email protected]> Signed-off-by: Rossi Sun <[email protected]>
1 parent 8509ca4 commit dadc21f

File tree

9 files changed

+271
-25
lines changed

9 files changed

+271
-25
lines changed

cpp/src/arrow/compute/function.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,14 +410,15 @@ Status Function::Validate() const {
410410
}
411411

412412
Status ScalarFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type,
413-
ArrayKernelExec exec, KernelInit init) {
413+
ArrayKernelExec exec, KernelInit init,
414+
std::shared_ptr<MatchConstraint> constraint) {
414415
RETURN_NOT_OK(CheckArity(in_types.size()));
415416

416417
if (arity_.is_varargs && in_types.size() != 1) {
417418
return Status::Invalid("VarArgs signatures must have exactly one input type");
418419
}
419-
auto sig =
420-
KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs);
420+
auto sig = KernelSignature::Make(std::move(in_types), std::move(out_type),
421+
arity_.is_varargs, std::move(constraint));
421422
kernels_.emplace_back(std::move(sig), exec, init);
422423
return Status::OK();
423424
}

cpp/src/arrow/compute/function.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,8 @@ class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl<ScalarKernel> {
308308
/// initialization, preallocation for fixed-width types, and default null
309309
/// handling (intersect validity bitmaps of inputs).
310310
Status AddKernel(std::vector<InputType> in_types, OutputType out_type,
311-
ArrayKernelExec exec, KernelInit init = NULLPTR);
311+
ArrayKernelExec exec, KernelInit init = NULLPTR,
312+
std::shared_ptr<MatchConstraint> constraint = NULLPTR);
312313

313314
/// \brief Add a kernel (function implementation). Returns error if the
314315
/// kernel's signature does not match the function's arity.

cpp/src/arrow/compute/kernel.cc

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -475,23 +475,93 @@ std::string OutputType::ToString() const {
475475
return "computed";
476476
}
477477

478+
// ----------------------------------------------------------------------
479+
// MatchConstraint
480+
481+
std::shared_ptr<MatchConstraint> MakeConstraint(
482+
std::function<bool(const std::vector<TypeHolder>&)> matches) {
483+
class FunctionMatchConstraint : public MatchConstraint {
484+
public:
485+
explicit FunctionMatchConstraint(
486+
std::function<bool(const std::vector<TypeHolder>&)> matches)
487+
: matches_(std::move(matches)) {}
488+
489+
bool Matches(const std::vector<TypeHolder>& types) const override {
490+
return matches_(types);
491+
}
492+
493+
private:
494+
std::function<bool(const std::vector<TypeHolder>&)> matches_;
495+
};
496+
497+
return std::make_shared<FunctionMatchConstraint>(std::move(matches));
498+
}
499+
500+
std::shared_ptr<MatchConstraint> DecimalsHaveSameScale() {
501+
class DecimalsHaveSameScaleConstraint : public MatchConstraint {
502+
public:
503+
bool Matches(const std::vector<TypeHolder>& types) const override {
504+
DCHECK_GE(types.size(), 2);
505+
DCHECK(std::all_of(types.begin(), types.end(),
506+
[](const TypeHolder& type) { return is_decimal(type.id()); }));
507+
const auto& ty0 = checked_cast<const DecimalType&>(*types[0].type);
508+
auto s0 = ty0.scale();
509+
for (size_t i = 1; i < types.size(); ++i) {
510+
const auto& ty = checked_cast<const DecimalType&>(*types[i].type);
511+
if (ty.scale() != s0) {
512+
return false;
513+
}
514+
}
515+
return true;
516+
}
517+
};
518+
static auto instance = std::make_shared<DecimalsHaveSameScaleConstraint>();
519+
return instance;
520+
}
521+
522+
namespace {
523+
524+
template <typename Op>
525+
class BinaryDecimalScaleComparisonConstraint : public MatchConstraint {
526+
public:
527+
bool Matches(const std::vector<TypeHolder>& types) const override {
528+
DCHECK_EQ(types.size(), 2);
529+
DCHECK(is_decimal(types[0].id()));
530+
DCHECK(is_decimal(types[1].id()));
531+
const auto& ty0 = checked_cast<const DecimalType&>(*types[0].type);
532+
const auto& ty1 = checked_cast<const DecimalType&>(*types[1].type);
533+
return Op{}(ty0.scale(), ty1.scale());
534+
}
535+
};
536+
537+
} // namespace
538+
539+
std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2() {
540+
using BinaryDecimalScale1GeScale2Constraint =
541+
BinaryDecimalScaleComparisonConstraint<std::greater_equal<>>;
542+
static auto instance = std::make_shared<BinaryDecimalScale1GeScale2Constraint>();
543+
return instance;
544+
}
545+
478546
// ----------------------------------------------------------------------
479547
// KernelSignature
480548

481549
KernelSignature::KernelSignature(std::vector<InputType> in_types, OutputType out_type,
482-
bool is_varargs)
550+
bool is_varargs,
551+
std::shared_ptr<MatchConstraint> constraint)
483552
: in_types_(std::move(in_types)),
484553
out_type_(std::move(out_type)),
485554
is_varargs_(is_varargs),
555+
constraint_(std::move(constraint)),
486556
hash_code_(0) {
487557
DCHECK(!is_varargs || (is_varargs && (in_types_.size() >= 1)));
488558
}
489559

490-
std::shared_ptr<KernelSignature> KernelSignature::Make(std::vector<InputType> in_types,
491-
OutputType out_type,
492-
bool is_varargs) {
560+
std::shared_ptr<KernelSignature> KernelSignature::Make(
561+
std::vector<InputType> in_types, OutputType out_type, bool is_varargs,
562+
std::shared_ptr<MatchConstraint> constraint) {
493563
return std::make_shared<KernelSignature>(std::move(in_types), std::move(out_type),
494-
is_varargs);
564+
is_varargs, std::move(constraint));
495565
}
496566

497567
bool KernelSignature::Equals(const KernelSignature& other) const {
@@ -526,6 +596,9 @@ bool KernelSignature::MatchesInputs(const std::vector<TypeHolder>& types) const
526596
}
527597
}
528598
}
599+
if (constraint_ && !constraint_->Matches(types)) {
600+
return false;
601+
}
529602
return true;
530603
}
531604

cpp/src/arrow/compute/kernel.h

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,28 @@ class ARROW_EXPORT OutputType {
348348
Resolver resolver_ = NULLPTR;
349349
};
350350

351-
/// \brief Holds the input types and output type of the kernel.
351+
/// \brief Additional constraints to apply to the input types of a kernel when matching a
352+
/// specific kernel signature.
353+
class ARROW_EXPORT MatchConstraint {
354+
public:
355+
virtual ~MatchConstraint() = default;
356+
357+
/// \brief Return true if the input types satisfy the constraint.
358+
virtual bool Matches(const std::vector<TypeHolder>& types) const = 0;
359+
};
360+
361+
/// \brief Convenience function to create a MatchConstraint from a match function.
362+
ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
363+
std::function<bool(const std::vector<TypeHolder>&)> matches);
364+
365+
/// \brief Constraint that all input types are decimal types and have the same scale.
366+
ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();
367+
368+
/// \brief Constraint that all binary input types are decimal types and the first type's
369+
/// scale >= the second type's.
370+
ARROW_EXPORT std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2();
371+
372+
/// \brief Holds the input types, optional match constraint and output type of the kernel.
352373
///
353374
/// VarArgs functions with minimum N arguments should pass up to N input types to be
354375
/// used to validate the input types of a function invocation. The first N-1 types
@@ -357,15 +378,16 @@ class ARROW_EXPORT OutputType {
357378
class ARROW_EXPORT KernelSignature {
358379
public:
359380
KernelSignature(std::vector<InputType> in_types, OutputType out_type,
360-
bool is_varargs = false);
381+
bool is_varargs = false,
382+
std::shared_ptr<MatchConstraint> constraint = NULLPTR);
361383

362384
/// \brief Convenience ctor since make_shared can be awkward
363-
static std::shared_ptr<KernelSignature> Make(std::vector<InputType> in_types,
364-
OutputType out_type,
365-
bool is_varargs = false);
385+
static std::shared_ptr<KernelSignature> Make(
386+
std::vector<InputType> in_types, OutputType out_type, bool is_varargs = false,
387+
std::shared_ptr<MatchConstraint> constraint = NULLPTR);
366388

367389
/// \brief Return true if the signature is compatible with the list of input
368-
/// value descriptors.
390+
/// value descriptors and satisfies the match constraint, if any.
369391
bool MatchesInputs(const std::vector<TypeHolder>& types) const;
370392

371393
/// \brief Returns true if the input types of each signature are
@@ -401,6 +423,7 @@ class ARROW_EXPORT KernelSignature {
401423
std::vector<InputType> in_types_;
402424
OutputType out_type_;
403425
bool is_varargs_;
426+
std::shared_ptr<MatchConstraint> constraint_;
404427

405428
// For caching the hash code after it's computed the first time
406429
mutable uint64_t hash_code_;

cpp/src/arrow/compute/kernel_test.cc

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,57 @@ TEST(OutputType, Resolve) {
307307
ASSERT_EQ(result, int32());
308308
}
309309

310+
// ----------------------------------------------------------------------
311+
// MatchConstraint
312+
313+
TEST(MatchConstraint, ConvenienceMaker) {
314+
{
315+
auto always_match =
316+
MakeConstraint([](const std::vector<TypeHolder>& types) { return true; });
317+
318+
ASSERT_TRUE(always_match->Matches({}));
319+
ASSERT_TRUE(always_match->Matches({int8(), int16(), int32()}));
320+
}
321+
322+
{
323+
auto always_false =
324+
MakeConstraint([](const std::vector<TypeHolder>& types) { return false; });
325+
326+
ASSERT_FALSE(always_false->Matches({}));
327+
ASSERT_FALSE(always_false->Matches({int8(), int16(), int32()}));
328+
}
329+
}
330+
331+
TEST(MatchConstraint, DecimalsHaveSameScale) {
332+
auto c = DecimalsHaveSameScale();
333+
constexpr int32_t precision = 12, scale = 2;
334+
ASSERT_TRUE(c->Matches({decimal128(precision, scale), decimal128(precision, scale)}));
335+
ASSERT_TRUE(c->Matches({decimal128(precision, scale), decimal256(precision, scale)}));
336+
ASSERT_TRUE(c->Matches({decimal256(precision, scale), decimal128(precision, scale)}));
337+
ASSERT_TRUE(c->Matches({decimal256(precision, scale), decimal256(precision, scale)}));
338+
ASSERT_FALSE(
339+
c->Matches({decimal128(precision, scale), decimal128(precision, scale + 1)}));
340+
ASSERT_FALSE(c->Matches({decimal128(precision, scale), decimal128(precision, scale),
341+
decimal128(precision, scale + 1)}));
342+
}
343+
344+
TEST(MatchConstraint, BinaryDecimalScaleComparisonGE) {
345+
auto c = BinaryDecimalScale1GeScale2();
346+
constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
347+
ASSERT_TRUE(
348+
c->Matches({decimal128(precision, big_scale), decimal128(precision, small_scale)}));
349+
ASSERT_TRUE(
350+
c->Matches({decimal128(precision, big_scale), decimal256(precision, small_scale)}));
351+
ASSERT_TRUE(
352+
c->Matches({decimal256(precision, big_scale), decimal128(precision, small_scale)}));
353+
ASSERT_TRUE(
354+
c->Matches({decimal256(precision, big_scale), decimal256(precision, small_scale)}));
355+
ASSERT_TRUE(c->Matches(
356+
{decimal128(precision, small_scale), decimal128(precision, small_scale)}));
357+
ASSERT_FALSE(
358+
c->Matches({decimal128(precision, small_scale), decimal128(precision, big_scale)}));
359+
}
360+
310361
// ----------------------------------------------------------------------
311362
// KernelSignature
312363

@@ -419,6 +470,35 @@ TEST(KernelSignature, VarArgsMatchesInputs) {
419470
}
420471
}
421472

473+
TEST(KernelSignature, MatchesInputsWithConstraint) {
474+
constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
475+
476+
auto small_scale_decimal = decimal128(precision, small_scale);
477+
auto big_scale_decimal = decimal128(precision, big_scale);
478+
479+
// No constraint.
480+
KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128}, boolean());
481+
ASSERT_TRUE(
482+
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
483+
ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));
484+
ASSERT_TRUE(
485+
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
486+
ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));
487+
488+
for (auto constraint : {DecimalsHaveSameScale(), BinaryDecimalScale1GeScale2()}) {
489+
KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(),
490+
/*is_varargs=*/false, constraint);
491+
ASSERT_EQ(constraint->Matches({small_scale_decimal, small_scale_decimal}),
492+
sig.MatchesInputs({small_scale_decimal, small_scale_decimal}));
493+
ASSERT_EQ(constraint->Matches({small_scale_decimal, big_scale_decimal}),
494+
sig.MatchesInputs({small_scale_decimal, big_scale_decimal}));
495+
ASSERT_EQ(constraint->Matches({big_scale_decimal, small_scale_decimal}),
496+
sig.MatchesInputs({big_scale_decimal, small_scale_decimal}));
497+
ASSERT_EQ(constraint->Matches({big_scale_decimal, big_scale_decimal}),
498+
sig.MatchesInputs({big_scale_decimal, big_scale_decimal}));
499+
}
500+
}
501+
422502
TEST(KernelSignature, ToString) {
423503
std::vector<InputType> in_types = {InputType(int8()), InputType(Type::DECIMAL),
424504
InputType(utf8())};

cpp/src/arrow/compute/kernels/scalar_arithmetic.cc

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -598,10 +598,6 @@ Result<TypeHolder> ResolveDecimalAdditionOrSubtractionOutput(
598598
types,
599599
[](int32_t p1, int32_t s1, int32_t p2,
600600
int32_t s2) -> Result<std::pair<int32_t, int32_t>> {
601-
if (s1 != s2) {
602-
return Status::Invalid("Addition or subtraction of two decimal ",
603-
"types scale1 != scale2. (", s1, s2, ").");
604-
}
605601
DCHECK_EQ(s1, s2);
606602
const int32_t scale = s1;
607603
const int32_t precision = std::max(p1 - s1, p2 - s2) + scale + 1;
@@ -627,10 +623,6 @@ Result<TypeHolder> ResolveDecimalDivisionOutput(KernelContext*,
627623
types,
628624
[](int32_t p1, int32_t s1, int32_t p2,
629625
int32_t s2) -> Result<std::pair<int32_t, int32_t>> {
630-
if (s1 < s2) {
631-
return Status::Invalid("Division of two decimal types scale1 < scale2. ", "(",
632-
s1, s2, ").");
633-
}
634626
DCHECK_GE(s1, s2);
635627
const int32_t scale = s1 - s2;
636628
const int32_t precision = p1;
@@ -669,13 +661,16 @@ void AddDecimalUnaryKernels(ScalarFunction* func) {
669661
template <typename Op>
670662
void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) {
671663
OutputType out_type(null());
664+
std::shared_ptr<MatchConstraint> constraint = nullptr;
672665
const std::string op = name.substr(0, name.find("_"));
673666
if (op == "add" || op == "subtract") {
674667
out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput);
668+
constraint = DecimalsHaveSameScale();
675669
} else if (op == "multiply") {
676670
out_type = OutputType(ResolveDecimalMultiplicationOutput);
677671
} else if (op == "divide") {
678672
out_type = OutputType(ResolveDecimalDivisionOutput);
673+
constraint = BinaryDecimalScale1GeScale2();
679674
} else {
680675
DCHECK(false);
681676
}
@@ -684,8 +679,10 @@ void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) {
684679
auto in_type256 = InputType(Type::DECIMAL256);
685680
auto exec128 = ScalarBinaryNotNullEqualTypes<Decimal128Type, Decimal128Type, Op>::Exec;
686681
auto exec256 = ScalarBinaryNotNullEqualTypes<Decimal256Type, Decimal256Type, Op>::Exec;
687-
DCHECK_OK(func->AddKernel({in_type128, in_type128}, out_type, exec128));
688-
DCHECK_OK(func->AddKernel({in_type256, in_type256}, out_type, exec256));
682+
DCHECK_OK(func->AddKernel({in_type128, in_type128}, out_type, exec128, /*init=*/nullptr,
683+
constraint));
684+
DCHECK_OK(func->AddKernel({in_type256, in_type256}, out_type, exec256, /*init=*/nullptr,
685+
constraint));
689686
}
690687

691688
template <typename Op>

0 commit comments

Comments
 (0)