Skip to content

Commit 2987165

Browse files
zanmato1984 and pitrou
authored
GH-41011: [C++][Compute] Fix the issue that comparison function could not handle decimal arguments with different scales (#47459)
### Rationale for this change We used to be unable to suppress exact matching for decimal arguments with different scales, even when the decimal comparison kernel actually requires the scales to be the same. This caused issues like #41011. The "match constraint" introduced in #47297 exists precisely for fixing issues like this, by simply adding a proper constraint. ### What changes are included in this PR? Added a match constraint that requires all decimal inputs to have the same scale (as is already done for decimal addition and subtraction). ### Are these changes tested? Yes. ### Are there any user-facing changes? None. * GitHub Issue: #41011 Lead-authored-by: Rossi Sun <[email protected]> Co-authored-by: Antoine Pitrou <[email protected]> Signed-off-by: Rossi Sun <[email protected]>
1 parent cdea48e commit 2987165

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

cpp/src/arrow/compute/expression_test.cc

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,11 @@ TEST(Expression, BindWithDecimalDivision) {
740740
}
741741

742742
TEST(Expression, BindWithImplicitCasts) {
743+
auto exciting_schema = schema(
744+
{field("i64", int64()), field("dec128_3_2", decimal128(3, 2)),
745+
field("dec128_4_2", decimal128(4, 2)), field("dec128_5_3", decimal128(5, 3)),
746+
field("dec256_3_2", decimal256(3, 2)), field("dec256_4_2", decimal256(4, 2)),
747+
field("dec256_5_3", decimal256(5, 3))});
743748
for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) {
744749
// cast arguments to common numeric type
745750
ExpectBindsTo(cmp(field_ref("i64"), field_ref("i32")),
@@ -800,6 +805,82 @@ TEST(Expression, BindWithImplicitCasts) {
800805
ExpectBindsTo(cmp(field_ref("i32"), literal(std::make_shared<DoubleScalar>(10.0))),
801806
cmp(cast(field_ref("i32"), float32()),
802807
literal(std::make_shared<FloatScalar>(10.0f))));
808+
809+
// decimal int
810+
ExpectBindsTo(cmp(field_ref("dec128_3_2"), field_ref("i64")),
811+
cmp(field_ref("dec128_3_2"), cast(field_ref("i64"), decimal128(21, 2))),
812+
/*bound_out=*/nullptr, *exciting_schema);
813+
ExpectBindsTo(cmp(field_ref("i64"), field_ref("dec128_3_2")),
814+
cmp(cast(field_ref("i64"), decimal128(21, 2)), field_ref("dec128_3_2")),
815+
/*bound_out=*/nullptr, *exciting_schema);
816+
817+
// decimal decimal with different widths different precisions but same scale
818+
ExpectBindsTo(
819+
cmp(field_ref("dec128_3_2"), field_ref("dec256_4_2")),
820+
cmp(cast(field_ref("dec128_3_2"), decimal256(3, 2)), field_ref("dec256_4_2")),
821+
/*bound_out=*/nullptr, *exciting_schema);
822+
ExpectBindsTo(
823+
cmp(field_ref("dec256_4_2"), field_ref("dec128_3_2")),
824+
cmp(field_ref("dec256_4_2"), cast(field_ref("dec128_3_2"), decimal256(3, 2))),
825+
/*bound_out=*/nullptr, *exciting_schema);
826+
ExpectBindsTo(
827+
cmp(field_ref("dec128_4_2"), field_ref("dec256_3_2")),
828+
cmp(cast(field_ref("dec128_4_2"), decimal256(4, 2)), field_ref("dec256_3_2")),
829+
/*bound_out=*/nullptr, *exciting_schema);
830+
ExpectBindsTo(
831+
cmp(field_ref("dec256_3_2"), field_ref("dec128_4_2")),
832+
cmp(field_ref("dec256_3_2"), cast(field_ref("dec128_4_2"), decimal256(4, 2))),
833+
/*bound_out=*/nullptr, *exciting_schema);
834+
835+
// decimal decimal with different widths different scales
836+
ExpectBindsTo(
837+
cmp(field_ref("dec128_3_2"), field_ref("dec256_5_3")),
838+
cmp(cast(field_ref("dec128_3_2"), decimal256(4, 3)), field_ref("dec256_5_3")),
839+
/*bound_out=*/nullptr, *exciting_schema);
840+
ExpectBindsTo(
841+
cmp(field_ref("dec256_5_3"), field_ref("dec128_3_2")),
842+
cmp(field_ref("dec256_5_3"), cast(field_ref("dec128_3_2"), decimal256(4, 3))),
843+
/*bound_out=*/nullptr, *exciting_schema);
844+
ExpectBindsTo(cmp(field_ref("dec128_5_3"), field_ref("dec256_3_2")),
845+
cmp(cast(field_ref("dec128_5_3"), decimal256(5, 3)),
846+
cast(field_ref("dec256_3_2"), decimal256(4, 3))),
847+
/*bound_out=*/nullptr, *exciting_schema);
848+
ExpectBindsTo(cmp(field_ref("dec256_3_2"), field_ref("dec128_5_3")),
849+
cmp(cast(field_ref("dec256_3_2"), decimal256(4, 3)),
850+
cast(field_ref("dec128_5_3"), decimal256(5, 3))),
851+
/*bound_out=*/nullptr, *exciting_schema);
852+
853+
// decimal decimal with same width same scale but different precisions (no cast)
854+
ExpectBindsTo(cmp(field_ref("dec128_3_2"), field_ref("dec128_4_2")),
855+
cmp(field_ref("dec128_3_2"), field_ref("dec128_4_2")),
856+
/*bound_out=*/nullptr, *exciting_schema);
857+
ExpectBindsTo(cmp(field_ref("dec128_4_2"), field_ref("dec128_3_2")),
858+
cmp(field_ref("dec128_4_2"), field_ref("dec128_3_2")),
859+
/*bound_out=*/nullptr, *exciting_schema);
860+
ExpectBindsTo(cmp(field_ref("dec256_3_2"), field_ref("dec256_4_2")),
861+
cmp(field_ref("dec256_3_2"), field_ref("dec256_4_2")),
862+
/*bound_out=*/nullptr, *exciting_schema);
863+
ExpectBindsTo(cmp(field_ref("dec256_4_2"), field_ref("dec256_3_2")),
864+
cmp(field_ref("dec256_4_2"), field_ref("dec256_3_2")),
865+
/*bound_out=*/nullptr, *exciting_schema);
866+
867+
// decimal decimal with same width but different scales
868+
ExpectBindsTo(
869+
cmp(field_ref("dec128_3_2"), field_ref("dec128_5_3")),
870+
cmp(cast(field_ref("dec128_3_2"), decimal128(4, 3)), field_ref("dec128_5_3")),
871+
/*bound_out=*/nullptr, *exciting_schema);
872+
ExpectBindsTo(
873+
cmp(field_ref("dec128_5_3"), field_ref("dec128_3_2")),
874+
cmp(field_ref("dec128_5_3"), cast(field_ref("dec128_3_2"), decimal128(4, 3))),
875+
/*bound_out=*/nullptr, *exciting_schema);
876+
ExpectBindsTo(
877+
cmp(field_ref("dec256_3_2"), field_ref("dec256_5_3")),
878+
cmp(cast(field_ref("dec256_3_2"), decimal256(4, 3)), field_ref("dec256_5_3")),
879+
/*bound_out=*/nullptr, *exciting_schema);
880+
ExpectBindsTo(
881+
cmp(field_ref("dec256_5_3"), field_ref("dec256_3_2")),
882+
cmp(field_ref("dec256_5_3"), cast(field_ref("dec256_3_2"), decimal256(4, 3))),
883+
/*bound_out=*/nullptr, *exciting_schema);
803884
}
804885

805886
compute::SetLookupOptions in_a{ArrayFromJSON(utf8(), R"(["a"])")};

cpp/src/arrow/compute/kernels/scalar_compare.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,8 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
436436

437437
for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
438438
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
439-
DCHECK_OK(
440-
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
439+
DCHECK_OK(func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec),
440+
/*init=*/nullptr, DecimalsHaveSameScale()));
441441
}
442442

443443
{

0 commit comments

Comments
 (0)