Skip to content

Commit 62eab1d

Browse files
committed
Fix the issue that comparison function could not handle decimal arguments with different scale
1 parent 6f6138b commit 62eab1d

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

cpp/src/arrow/compute/expression_test.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,10 @@ TEST(Expression, BindWithDecimalArithmeticOps) {
648648
}
649649

650650
TEST(Expression, BindWithImplicitCasts) {
651+
auto exciting_schema = schema(
652+
{field("i64", int64()), field("dec128_3_2", decimal128(3, 2)),
653+
field("dec128_5_3", decimal128(5, 3)), field("dec256_3_2", decimal256(3, 2)),
654+
field("dec256_5_3", decimal256(5, 3))});
651655
for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) {
652656
// cast arguments to common numeric type
653657
ExpectBindsTo(cmp(field_ref("i64"), field_ref("i32")),
@@ -708,6 +712,42 @@ TEST(Expression, BindWithImplicitCasts) {
708712
ExpectBindsTo(cmp(field_ref("i32"), literal(std::make_shared<DoubleScalar>(10.0))),
709713
cmp(cast(field_ref("i32"), float32()),
710714
literal(std::make_shared<FloatScalar>(10.0f))));
715+
716+
// decimal int
717+
ExpectBindsTo(cmp(field_ref("dec128_3_2"), field_ref("i64")),
718+
cmp(field_ref("dec128_3_2"), cast(field_ref("i64"), decimal128(21, 2))),
719+
/*bound_out=*/nullptr, *exciting_schema);
720+
ExpectBindsTo(cmp(field_ref("i64"), field_ref("dec128_3_2")),
721+
cmp(cast(field_ref("i64"), decimal128(21, 2)), field_ref("dec128_3_2")),
722+
/*bound_out=*/nullptr, *exciting_schema);
723+
724+
// decimal128 decimal256 with different scales
725+
ExpectBindsTo(
726+
cmp(field_ref("dec128_3_2"), field_ref("dec256_5_3")),
727+
cmp(cast(field_ref("dec128_3_2"), decimal256(4, 3)), field_ref("dec256_5_3")),
728+
/*bound_out=*/nullptr, *exciting_schema);
729+
ExpectBindsTo(
730+
cmp(field_ref("dec256_5_3"), field_ref("dec128_3_2")),
731+
cmp(field_ref("dec256_5_3"), cast(field_ref("dec128_3_2"), decimal256(4, 3))),
732+
/*bound_out=*/nullptr, *exciting_schema);
733+
ExpectBindsTo(cmp(field_ref("dec128_5_3"), field_ref("dec256_3_2")),
734+
cmp(cast(field_ref("dec128_5_3"), decimal256(5, 3)),
735+
cast(field_ref("dec256_3_2"), decimal256(4, 3))),
736+
/*bound_out=*/nullptr, *exciting_schema);
737+
ExpectBindsTo(cmp(field_ref("dec256_3_2"), field_ref("dec128_5_3")),
738+
cmp(cast(field_ref("dec256_3_2"), decimal256(4, 3)),
739+
cast(field_ref("dec128_5_3"), decimal256(5, 3))),
740+
/*bound_out=*/nullptr, *exciting_schema);
741+
742+
// decimal decimal with different scales
743+
ExpectBindsTo(
744+
cmp(field_ref("dec128_3_2"), field_ref("dec128_5_3")),
745+
cmp(cast(field_ref("dec128_3_2"), decimal128(4, 3)), field_ref("dec128_5_3")),
746+
/*bound_out=*/nullptr, *exciting_schema);
747+
ExpectBindsTo(
748+
cmp(field_ref("dec128_5_3"), field_ref("dec128_3_2")),
749+
cmp(field_ref("dec128_5_3"), cast(field_ref("dec128_3_2"), decimal128(4, 3))),
750+
/*bound_out=*/nullptr, *exciting_schema);
711751
}
712752

713753
compute::SetLookupOptions in_a{ArrayFromJSON(utf8(), R"(["a"])")};

cpp/src/arrow/compute/kernels/scalar_compare.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,8 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
436436

437437
for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
438438
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
439-
DCHECK_OK(
440-
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
439+
DCHECK_OK(func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec),
440+
/*init=*/nullptr, DecimalsHaveSameScale()));
441441
}
442442

443443
{

0 commit comments

Comments
 (0)