Skip to content

Commit deb73a7

Browse files
authored
GH-35957: [C++][Compute] Graceful error for decimal binary arithmetic and comparison instead of firing confusing assertion (#48639)
### Rationale for this change When dispatching binary arithmetic and comparison kernels, we perform a special cast ahead of time for decimal arguments. If one argument is a decimal and the other is of a type that is not castable to decimal (e.g., string), an assertion fires. On the other hand, the general kernel dispatching path that runs after this special cast already has a graceful way to report a dispatch failure: ``` Function 'greater' has no kernel matching input types (string, double) ``` We want to unify the error path for decimals. ### What changes are included in this PR? Bypass the decimal cast early, without raising an error, if the other argument is not castable to decimal, and let the subsequent general kernel dispatching path handle the error gracefully. ### Are these changes tested? Test included. ### Are there any user-facing changes? None. * GitHub Issue: #35957 Authored-by: Rossi Sun <[email protected]> Signed-off-by: Rossi Sun <[email protected]>
1 parent 727106f commit deb73a7

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

cpp/src/arrow/compute/kernels/codegen_internal.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,22 @@ TypeHolder CommonBinary(const TypeHolder* begin, size_t count) {
393393
return large_binary();
394394
}
395395

396+
// Returns true when `type` can take part in the decimal promotion performed by
// CastBinaryDecimalArgs, i.e. when its type id is reported as numeric by
// is_numeric() or as decimal by is_decimal(). Any other type yields false, which
// CastBinaryDecimalArgs uses as the signal to leave the argument types untouched.
bool CastableToDecimal(const DataType& type) {
397+
return is_numeric(type.id()) || is_decimal(type.id());
398+
}
399+
396400
Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector<TypeHolder>* types) {
397401
const DataType& left_type = *(*types)[0];
398402
const DataType& right_type = *(*types)[1];
399403
DCHECK(is_decimal(left_type.id()) || is_decimal(right_type.id()));
400404

405+
if ((is_decimal(left_type.id()) && !CastableToDecimal(right_type)) ||
406+
(is_decimal(right_type.id()) && !CastableToDecimal(left_type))) {
407+
// If the other type is not castable to decimal, do not cast. The dispatch will
408+
// gracefully fail by kernel selection.
409+
return Status::OK();
410+
}
411+
401412
// decimal + float64 = float64
402413
// decimal + float32 is roughly float64 + float32 so we choose float64
403414
if (is_floating(left_type.id()) || is_floating(right_type.id())) {

cpp/src/arrow/compute/kernels/codegen_internal_test.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,22 @@ TEST(TestDispatchBest, CastBinaryDecimalArgs) {
5151
EXPECT_RAISES_WITH_MESSAGE_THAT(
5252
NotImplemented, ::testing::HasSubstr("Decimals with negative scales not supported"),
5353
CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
54+
55+
// Non-castable -> unchanged
56+
for (const auto promotion :
57+
{DecimalPromotion::kAdd, DecimalPromotion::kMultiply, DecimalPromotion::kDivide}) {
58+
for (const auto& args : std::vector<std::vector<TypeHolder>>{
59+
{decimal128(3, 2), boolean()},
60+
{boolean(), decimal128(3, 2)},
61+
{decimal128(3, 2), utf8()},
62+
{utf8(), decimal128(3, 2)},
63+
}) {
64+
auto args_copy = args;
65+
ASSERT_OK(CastBinaryDecimalArgs(promotion, &args_copy));
66+
AssertTypeEqual(*args_copy[0], *args[0]);
67+
AssertTypeEqual(*args_copy[1], *args[1]);
68+
}
69+
}
5470
}
5571

5672
TEST(TestDispatchBest, CastDecimalArgs) {

cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2494,6 +2494,29 @@ TEST_F(TestBinaryArithmeticDecimal, Power) {
24942494
}
24952495
}
24962496

2497+
// Checks that the binary arithmetic functions report a kernel dispatch failure
// (NotImplemented, message containing "has no kernel matching") when one argument
// is a decimal and the other is boolean, fixed_size_binary(42) or utf8, in either
// argument order.
TEST_F(TestBinaryArithmeticDecimal, ErrorOnNonCastable) {
2498+
for (const auto& name : {"add", "subtract", "multiply", "divide"}) {
2499+
// Cover both the plain and the "_checked" variant of each function name.
for (const auto& suffix : {"", "_checked"}) {
2500+
auto func = std::string(name) + suffix;
2501+
SCOPED_TRACE(func);
2502+
for (const auto& dec_ty : PositiveScaleTypes()) {
2503+
SCOPED_TRACE(dec_ty->ToString());
2504+
// The arrays are empty; only the raised error is inspected.
auto dec_arr = ArrayFromJSON(dec_ty, R"([])");
2505+
for (const auto& other_ty : {boolean(), fixed_size_binary(42), utf8()}) {
2506+
SCOPED_TRACE(other_ty->ToString());
2507+
auto other_arr = ArrayFromJSON(other_ty, R"([])");
2508+
// Decimal on the left, other type on the right.
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
2509+
::testing::HasSubstr("has no kernel matching"),
2510+
CallFunction(func, {dec_arr, other_arr}));
2511+
// Same pair with the argument order swapped.
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
2512+
::testing::HasSubstr("has no kernel matching"),
2513+
CallFunction(func, {other_arr, dec_arr}));
2514+
}
2515+
}
2516+
}
2517+
}
2518+
}
2519+
24972520
TYPED_TEST(TestBinaryArithmeticIntegral, ShiftLeft) {
24982521
for (auto check_overflow : {false, true}) {
24992522
this->SetOverflowCheck(check_overflow);

cpp/src/arrow/compute/kernels/scalar_compare_test.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,26 @@ TYPED_TEST(TestCompareDecimal, DifferentParameters) {
681681
}
682682
}
683683

684+
// Checks that every comparison function reports a kernel dispatch failure
// (NotImplemented, message containing "has no kernel matching") when one argument
// is of the decimal type under test and the other is boolean,
// fixed_size_binary(42) or utf8, in either argument order.
TYPED_TEST(TestCompareDecimal, ErrorOnNonCastable) {
685+
// Decimal type under test, constructed with arguments (3, 2).
auto dec_ty = std::make_shared<TypeParam>(3, 2);
686+
// The arrays are empty; only the raised error is inspected.
auto dec_arr = ArrayFromJSON(dec_ty, R"([])");
687+
688+
for (const auto& func :
689+
{"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) {
690+
SCOPED_TRACE(func);
691+
for (const auto& other_ty : {boolean(), fixed_size_binary(42), utf8()}) {
692+
SCOPED_TRACE(other_ty->ToString());
693+
auto other_arr = ArrayFromJSON(other_ty, R"([])");
694+
// Decimal on the left, other type on the right.
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
695+
::testing::HasSubstr("has no kernel matching"),
696+
CallFunction(func, {dec_arr, other_arr}));
697+
// Same pair with the argument order swapped.
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
698+
::testing::HasSubstr("has no kernel matching"),
699+
CallFunction(func, {other_arr, dec_arr}));
700+
}
701+
}
702+
}
703+
684704
// Helper to organize tests for fixed size binary comparisons
685705
struct CompareCase {
686706
std::shared_ptr<DataType> lhs_type;

0 commit comments

Comments
 (0)