Skip to content

Commit 8ca5f83

Browse files
committed
fix decimal compare wrong in compute expression
1 parent 831b94a commit 8ca5f83

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

cpp/src/arrow/compute/expression_test.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,46 @@ void ExpectExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) {
845845
}
846846
}
847847

848+
TEST(Expression, ExecuteCallWithDecimalComparisonOps) {
849+
// GH-41011, make sure the decimal's comparison operations are casted
850+
// in expression bind and make correct results in expression execute
851+
ExpectExecute(
852+
call("not_equal", {field_ref("d1"), field_ref("d2")}),
853+
ArrayFromJSON(struct_({field("d1", decimal(2, 0)), field("d2", decimal(2, 1))}),
854+
R"([
855+
{"d1": "40", "d2": "4.0"},
856+
{"d1": "20", "d2": "2.0"}
857+
])"));
858+
859+
ExpectExecute(
860+
call("less", {field_ref("d1"), field_ref("d2")}),
861+
ArrayFromJSON(struct_({field("d1", decimal(2, 1)), field("d2", decimal(2, 0))}),
862+
R"([
863+
{"d1": "4.0", "d2": "40"},
864+
{"d1": "2.0", "d2": "20"}
865+
])"));
866+
867+
for (std::string fname : {"less_equal", "equal"}) {
868+
ExpectExecute(
869+
call(fname, {field_ref("d1"), field_ref("d2")}),
870+
ArrayFromJSON(struct_({field("d1", decimal(3, 2)), field("d2", decimal(2, 1))}),
871+
R"([
872+
{"d1": "3.10", "d2": "3.1"},
873+
{"d1": "2.10", "d2": "2.1"}
874+
])"));
875+
}
876+
877+
for (std::string fname : {"greater_equal", "greater"}) {
878+
ExpectExecute(
879+
call(fname, {field_ref("d1"), field_ref("d2")}),
880+
ArrayFromJSON(struct_({field("d1", decimal(2, 0)), field("d2", decimal(2, 1))}),
881+
R"([
882+
{"d1": "4", "d2": "3.0"},
883+
{"d1": "3", "d2": "2.0"}
884+
])"));
885+
}
886+
}
887+
848888
TEST(Expression, ExecuteCall) {
849889
ExpectExecute(add(field_ref("a"), literal(3.5)),
850890
ArrayFromJSON(struct_({field("a", float64())}), R"([

cpp/src/arrow/compute/kernels/scalar_compare.cc

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,23 @@ struct VarArgsCompareFunction : ScalarFunction {
385385
}
386386
};
387387

388+
Result<TypeHolder> ResolveDecimalCompareOutputType(KernelContext*,
389+
const std::vector<TypeHolder>& types) {
390+
// casted types should be same size decimals
391+
const auto& left_type = checked_cast<const DecimalType&>(*types[0]);
392+
const auto& right_type = checked_cast<const DecimalType&>(*types[1]);
393+
DCHECK_EQ(left_type.id(), right_type.id());
394+
395+
// check the casted decimal scales according kAdd promotion rule
396+
const int32_t s1 = left_type.scale();
397+
const int32_t s2 = right_type.scale();
398+
if (s1 != s2) {
399+
return Status::Invalid("Comparison of two decimal ", "types s1 != s2. (", s1, s2,
400+
").");
401+
}
402+
return boolean();
403+
}
404+
388405
template <typename Op>
389406
std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDoc doc) {
390407
auto func = std::make_shared<CompareFunction>(name, Arity::Binary(), std::move(doc));
@@ -433,9 +450,9 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
433450
}
434451

435452
for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
453+
OutputType out_type(ResolveDecimalCompareOutputType);
436454
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
437-
DCHECK_OK(
438-
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
455+
DCHECK_OK(func->AddKernel({InputType(id), InputType(id)}, out_type, std::move(exec)));
439456
}
440457

441458
{

0 commit comments

Comments
 (0)