@@ -385,21 +385,52 @@ struct VarArgsCompareFunction : ScalarFunction {
385385 }
386386};
387387
388- Result<TypeHolder> ResolveDecimalCompareOutputType (KernelContext*,
389- const std::vector<TypeHolder>& types) {
390- // casted types should be same size decimals
391- const auto & left_type = checked_cast<const DecimalType&>(*types[0 ]);
392- const auto & right_type = checked_cast<const DecimalType&>(*types[1 ]);
393- DCHECK_EQ (left_type.id (), right_type.id ());
394-
395- // check the casted decimal scales according kAdd promotion rule
396- const int32_t s1 = left_type.scale ();
397- const int32_t s2 = right_type.scale ();
398- if (s1 != s2) {
399- return Status::Invalid (" Comparison of two decimal " , " types s1 != s2. (" , s1, s2,
400- " )." );
401- }
402- return boolean ();
388+ class DecimalTypesCompareMatcher : public TypeMatcher {
389+ public:
390+ explicit DecimalTypesCompareMatcher (std::shared_ptr<TypeMatcher> decimal_type_matcher)
391+ : decimal_type_matcher(std::move(decimal_type_matcher)) {}
392+
393+ bool Matches (const DataType& type) const override {
394+ return decimal_type_matcher->Matches (type);
395+ }
396+
397+ bool Matches (const std::vector<TypeHolder>& types) const override {
398+ DCHECK_EQ (types.size (), 2 );
399+ if (!is_decimal (*types[0 ]) || !is_decimal (*types[1 ])) {
400+ return true ;
401+ }
402+
403+ // Below match logic should only be executed when types are both decimal
404+ //
405+ const auto & left_type = checked_cast<const DecimalType&>(*types[0 ]);
406+ const auto & right_type = checked_cast<const DecimalType&>(*types[1 ]);
407+
408+ // check the decimal types' scales according kAdd promotion rule
409+ const int32_t s1 = left_type.scale ();
410+ const int32_t s2 = right_type.scale ();
411+ if (s1 != s2) {
412+ return false ;
413+ }
414+ return true ;
415+ }
416+
417+ bool Equals (const TypeMatcher& other) const override {
418+ if (this == &other) {
419+ return true ;
420+ }
421+ const auto * casted = dynamic_cast <const DecimalTypesCompareMatcher*>(&other);
422+ return casted != nullptr &&
423+ decimal_type_matcher->Equals (*casted->decimal_type_matcher );
424+ }
425+
426+ std::string ToString () const override { return " decimal-types-matcher" ; }
427+
428+ private:
429+ std::shared_ptr<TypeMatcher> decimal_type_matcher;
430+ };
431+
432+ std::shared_ptr<TypeMatcher> DecimalTypesMatcher (Type::type type_id) {
433+ return std::make_shared<DecimalTypesCompareMatcher>(match::SameTypeId (type_id));
403434}
404435
405436template <typename Op>
@@ -450,9 +481,9 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
450481 }
451482
452483 for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
453- OutputType out_type (ResolveDecimalCompareOutputType );
484+ InputType in_type ( DecimalTypesMatcher (id) );
454485 auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
455- DCHECK_OK (func->AddKernel ({InputType (id), InputType (id) }, out_type , std::move (exec)));
486+ DCHECK_OK (func->AddKernel ({in_type, in_type }, boolean () , std::move (exec)));
456487 }
457488
458489 {
0 commit comments