diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index 30bd882b2c03..85e7005c3e44 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -845,6 +845,46 @@ void ExpectExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) { } } +TEST(Expression, ExecuteCallWithDecimalComparisonOps) { + // GH-41011: make sure the arguments of decimal comparison operations are cast + // during expression bind and produce correct results during expression execute + ExpectExecute( + call("not_equal", {field_ref("d1"), field_ref("d2")}), + ArrayFromJSON(struct_({field("d1", decimal(2, 0)), field("d2", decimal(2, 1))}), + R"([ + {"d1": "40", "d2": "4.0"}, + {"d1": "20", "d2": "2.0"} + ])")); + + ExpectExecute( + call("less", {field_ref("d1"), field_ref("d2")}), + ArrayFromJSON(struct_({field("d1", decimal(2, 1)), field("d2", decimal(2, 0))}), + R"([ + {"d1": "4.0", "d2": "40"}, + {"d1": "2.0", "d2": "20"} + ])")); + + for (std::string fname : {"less_equal", "equal"}) { + ExpectExecute( + call(fname, {field_ref("d1"), field_ref("d2")}), + ArrayFromJSON(struct_({field("d1", decimal(3, 2)), field("d2", decimal(2, 1))}), + R"([ + {"d1": "3.10", "d2": "3.1"}, + {"d1": "2.10", "d2": "2.1"} + ])")); + } + + for (std::string fname : {"greater_equal", "greater"}) { + ExpectExecute( + call(fname, {field_ref("d1"), field_ref("d2")}), + ArrayFromJSON(struct_({field("d1", decimal(2, 0)), field("d2", decimal(2, 1))}), + R"([ + {"d1": "4", "d2": "3.0"}, + {"d1": "3", "d2": "2.0"} + ])")); + } +} + TEST(Expression, ExecuteCall) { ExpectExecute(add(field_ref("a"), literal(3.5)), ArrayFromJSON(struct_({field("a", float64())}), R"([ diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index fd554ba3d83c..46b679993a79 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -430,6 +430,11 @@ bool InputType::Matches(const Datum& value) const { return 
Matches(*value.type()); } +bool InputType::Matches(const std::vector& types) const { + DCHECK_EQ(InputType::USE_TYPE_MATCHER, kind_); + return type_matcher_->Matches(types); +} + const std::shared_ptr& InputType::type() const { DCHECK_EQ(InputType::EXACT_TYPE, kind_); return type_; @@ -505,9 +510,14 @@ bool KernelSignature::Equals(const KernelSignature& other) const { } bool KernelSignature::MatchesInputs(const std::vector& types) const { + auto is_match_combination_types = [&](const InputType& in_type) { + return in_type.kind() == InputType::USE_TYPE_MATCHER ? in_type.Matches(types) : true; + }; + if (is_varargs_) { for (size_t i = 0; i < types.size(); ++i) { - if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(*types[i])) { + const auto& in_type = in_types_[std::min(i, in_types_.size() - 1)]; + if (!in_type.Matches(*types[i]) || !is_match_combination_types(in_type)) { return false; } } @@ -516,7 +526,7 @@ bool KernelSignature::MatchesInputs(const std::vector& types) const return false; } for (size_t i = 0; i < in_types_.size(); ++i) { - if (!in_types_[i].Matches(*types[i])) { + if (!in_types_[i].Matches(*types[i]) || !is_match_combination_types(in_types_[i])) { return false; } } diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 1adb3e96c97c..251b3577fb0b 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -109,6 +109,9 @@ struct ARROW_EXPORT TypeMatcher { /// \brief Return true if this matcher accepts the data type. virtual bool Matches(const DataType& type) const = 0; + /// \brief Return true if this matcher accepts the combination of types. + virtual bool Matches(const std::vector& types) const { return true; } + /// \brief A human-interpretable string representation of what the type /// matcher checks for, usable when printing KernelSignature or formatting /// error messages. 
@@ -241,6 +244,10 @@ class ARROW_EXPORT InputType { /// \brief Return true if the type matches this InputType bool Matches(const DataType& type) const; + /// \brief Return true if the combination of input types is accepted by this + /// InputType's type matcher; only valid when kind() is USE_TYPE_MATCHER. + bool Matches(const std::vector& types) const; + /// \brief The type matching rule that this InputType uses. Kind kind() const { return kind_; } diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index daf8ed76d628..b78196de37a0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -385,6 +385,54 @@ struct VarArgsCompareFunction : ScalarFunction { } }; +class DecimalTypesCompareMatcher : public TypeMatcher { + public: + explicit DecimalTypesCompareMatcher(std::shared_ptr decimal_type_matcher) + : decimal_type_matcher(std::move(decimal_type_matcher)) {} + + bool Matches(const DataType& type) const override { + return decimal_type_matcher->Matches(type); + } + + bool Matches(const std::vector& types) const override { + DCHECK_EQ(types.size(), 2); + if (!is_decimal(*types[0]) || !is_decimal(*types[1])) { + return true; + } + + // The match logic below only applies when both types are decimal + // (any other combination has already been accepted above) + const auto& left_type = checked_cast(*types[0]); + const auto& right_type = checked_cast(*types[1]); + + // Check the decimal types' scales according to the kAdd promotion rule + const int32_t s1 = left_type.scale(); + const int32_t s2 = right_type.scale(); + if (s1 != s2) { + return false; + } + return true; + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + const auto* casted = dynamic_cast(&other); + return casted != nullptr && + decimal_type_matcher->Equals(*casted->decimal_type_matcher); + } + + std::string ToString() const override { return "decimal-types-matcher"; } + + private: + std::shared_ptr decimal_type_matcher; +}; + 
+std::shared_ptr DecimalTypesMatcher(Type::type type_id) { + return std::make_shared(match::SameTypeId(type_id)); +} + template std::shared_ptr MakeCompareFunction(std::string name, FunctionDoc doc) { auto func = std::make_shared(name, Arity::Binary(), std::move(doc)); @@ -433,9 +481,9 @@ std::shared_ptr MakeCompareFunction(std::string name, FunctionDo } for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) { + InputType in_type(DecimalTypesMatcher(id)); auto exec = GenerateDecimal(id); - DCHECK_OK( - func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec))); + DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); } {