Skip to content

Commit a7703b9

Browse files
committed
support type_matcher for combination input types in decimal compare kernel
1 parent 8ca5f83 commit a7703b9

File tree

3 files changed

+67
-19
lines changed

3 files changed

+67
-19
lines changed

cpp/src/arrow/compute/kernel.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,11 @@ bool InputType::Matches(const Datum& value) const {
430430
return Matches(*value.type());
431431
}
432432

433+
bool InputType::Matches(const std::vector<TypeHolder>& types) const {
434+
DCHECK_EQ(InputType::USE_TYPE_MATCHER, kind_);
435+
return type_matcher_->Matches(types);
436+
}
437+
433438
const std::shared_ptr<DataType>& InputType::type() const {
434439
DCHECK_EQ(InputType::EXACT_TYPE, kind_);
435440
return type_;
@@ -505,9 +510,14 @@ bool KernelSignature::Equals(const KernelSignature& other) const {
505510
}
506511

507512
bool KernelSignature::MatchesInputs(const std::vector<TypeHolder>& types) const {
513+
auto is_match_combination_types = [&](const InputType& in_type) {
514+
return in_type.kind() == InputType::USE_TYPE_MATCHER ? in_type.Matches(types) : true;
515+
};
516+
508517
if (is_varargs_) {
509518
for (size_t i = 0; i < types.size(); ++i) {
510-
if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(*types[i])) {
519+
const auto& in_type = in_types_[std::min(i, in_types_.size() - 1)];
520+
if (!in_type.Matches(*types[i]) || !is_match_combination_types(in_type)) {
511521
return false;
512522
}
513523
}
@@ -516,7 +526,7 @@ bool KernelSignature::MatchesInputs(const std::vector<TypeHolder>& types) const
516526
return false;
517527
}
518528
for (size_t i = 0; i < in_types_.size(); ++i) {
519-
if (!in_types_[i].Matches(*types[i])) {
529+
if (!in_types_[i].Matches(*types[i]) || !is_match_combination_types(in_types_[i])) {
520530
return false;
521531
}
522532
}

cpp/src/arrow/compute/kernel.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ struct ARROW_EXPORT TypeMatcher {
109109
/// \brief Return true if this matcher accepts the data type.
110110
virtual bool Matches(const DataType& type) const = 0;
111111

112+
/// \brief Return true if this matcher accepts the combination types
113+
virtual bool Matches(const std::vector<TypeHolder>& types) const { return true; }
114+
112115
/// \brief A human-interpretable string representation of what the type
113116
/// matcher checks for, usable when printing KernelSignature or formatting
114117
/// error messages.
@@ -241,6 +244,10 @@ class ARROW_EXPORT InputType {
241244
/// \brief Return true if the type matches this InputType
242245
bool Matches(const DataType& type) const;
243246

247+
/// \brief Return true if the input combination types matches this
248+
/// InputType's type_matcher matched rules.
249+
bool Matches(const std::vector<TypeHolder>& types) const;
250+
244251
/// \brief The type matching rule that this InputType uses.
245252
Kind kind() const { return kind_; }
246253

cpp/src/arrow/compute/kernels/scalar_compare.cc

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -385,21 +385,52 @@ struct VarArgsCompareFunction : ScalarFunction {
385385
}
386386
};
387387

388-
Result<TypeHolder> ResolveDecimalCompareOutputType(KernelContext*,
389-
const std::vector<TypeHolder>& types) {
390-
// casted types should be same size decimals
391-
const auto& left_type = checked_cast<const DecimalType&>(*types[0]);
392-
const auto& right_type = checked_cast<const DecimalType&>(*types[1]);
393-
DCHECK_EQ(left_type.id(), right_type.id());
394-
395-
// check the casted decimal scales according kAdd promotion rule
396-
const int32_t s1 = left_type.scale();
397-
const int32_t s2 = right_type.scale();
398-
if (s1 != s2) {
399-
return Status::Invalid("Comparison of two decimal ", "types s1 != s2. (", s1, s2,
400-
").");
401-
}
402-
return boolean();
388+
class DecimalTypesCompareMatcher : public TypeMatcher {
389+
public:
390+
explicit DecimalTypesCompareMatcher(std::shared_ptr<TypeMatcher> decimal_type_matcher)
391+
: decimal_type_matcher(std::move(decimal_type_matcher)) {}
392+
393+
bool Matches(const DataType& type) const override {
394+
return decimal_type_matcher->Matches(type);
395+
}
396+
397+
bool Matches(const std::vector<TypeHolder>& types) const override {
398+
DCHECK_EQ(types.size(), 2);
399+
if (!is_decimal(*types[0]) || !is_decimal(*types[1])) {
400+
return true;
401+
}
402+
403+
// Below match logic should only be executed when types are both decimal
404+
//
405+
const auto& left_type = checked_cast<const DecimalType&>(*types[0]);
406+
const auto& right_type = checked_cast<const DecimalType&>(*types[1]);
407+
408+
// check the decimal types' scales according kAdd promotion rule
409+
const int32_t s1 = left_type.scale();
410+
const int32_t s2 = right_type.scale();
411+
if (s1 != s2) {
412+
return false;
413+
}
414+
return true;
415+
}
416+
417+
bool Equals(const TypeMatcher& other) const override {
418+
if (this == &other) {
419+
return true;
420+
}
421+
const auto* casted = dynamic_cast<const DecimalTypesCompareMatcher*>(&other);
422+
return casted != nullptr &&
423+
decimal_type_matcher->Equals(*casted->decimal_type_matcher);
424+
}
425+
426+
std::string ToString() const override { return "decimal-types-matcher"; }
427+
428+
private:
429+
std::shared_ptr<TypeMatcher> decimal_type_matcher;
430+
};
431+
432+
std::shared_ptr<TypeMatcher> DecimalTypesMatcher(Type::type type_id) {
433+
return std::make_shared<DecimalTypesCompareMatcher>(match::SameTypeId(type_id));
403434
}
404435

405436
template <typename Op>
@@ -450,9 +481,9 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
450481
}
451482

452483
for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
453-
OutputType out_type(ResolveDecimalCompareOutputType);
484+
InputType in_type(DecimalTypesMatcher(id));
454485
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
455-
DCHECK_OK(func->AddKernel({InputType(id), InputType(id)}, out_type, std::move(exec)));
486+
DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
456487
}
457488

458489
{

0 commit comments

Comments
 (0)