Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -634,19 +634,72 @@ TEST(Expression, BindWithAliasCasts) {
}

TEST(Expression, BindWithDecimalArithmeticOps) {
for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) {
auto expr = call(arith_op, {field_ref("d1"), field_ref("d2")});
EXPECT_FALSE(expr.IsBound());

static const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
for (auto s : scales) {
auto schema = arrow::schema(
{field("d1", decimal256(30, s.first)), field("d2", decimal256(20, s.second))});
ExpectBindsTo(expr, no_change, &expr, *schema);
static const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};

for (std::string suffix : {"", "_checked"}) {
for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) {
std::string name = arith_op + suffix;
SCOPED_TRACE(name);

for (auto s : scales) {
auto schema = arrow::schema({field("d1", decimal256(30, s.first)),
field("d2", decimal256(20, s.second))});
auto expr = call(name, {field_ref("d1"), field_ref("d2")});
EXPECT_FALSE(expr.IsBound());
ExpectBindsTo(expr, no_change, &expr, *schema);
}
}
}
}

TEST(Expression, BindWithDecimalDivision) {
auto expect_decimal_division_type = [](std::string name,
std::shared_ptr<DataType> dividend,
std::shared_ptr<DataType> divisor,
std::shared_ptr<DataType> expected) {
auto schema = arrow::schema({field("dividend", dividend), field("divisor", divisor)});
auto expr = call(name, {field_ref("dividend"), field_ref("divisor")});
ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*schema));
EXPECT_TRUE(bound.IsBound());
EXPECT_TRUE(bound.type()->Equals(expected));
};

for (std::string name : {"divide", "divide_checked"}) {
SCOPED_TRACE(name);

expect_decimal_division_type(name, int64(), arrow::decimal128(1, 0),
decimal128(23, 4));
expect_decimal_division_type(name, arrow::decimal128(1, 0), int64(),
decimal128(21, 20));

expect_decimal_division_type(name, decimal128(2, 1), decimal128(2, 1),
decimal128(6, 4));
expect_decimal_division_type(name, decimal256(2, 1), decimal256(2, 1),
decimal256(6, 4));
expect_decimal_division_type(name, decimal128(2, 1), decimal256(2, 1),
decimal256(6, 4));
expect_decimal_division_type(name, decimal256(2, 1), decimal128(2, 1),
decimal256(6, 4));

expect_decimal_division_type(name, decimal128(2, 0), decimal128(2, 1),
decimal128(7, 4));
expect_decimal_division_type(name, decimal128(2, 1), decimal128(2, 0),
decimal128(5, 4));

// GH-39875: Expression call to decimal(3 ,2) / decimal(15, 2) wrong result type.
// decimal128(3, 2) / decimal128(15, 2)
// -> decimal128(19, 18) / decimal128(15, 2) = decimal128(19, 16)
expect_decimal_division_type(name, decimal128(3, 2), decimal128(15, 2),
decimal128(19, 16));

// GH-40911: Expression call to decimal(7 ,2) / decimal(6, 1) wrong result type.
// decimal128(7, 2) / decimal128(6, 1)
// -> decimal128(14, 9) / decimal128(6, 1) = decimal128(14, 8)
expect_decimal_division_type(name, decimal128(7, 2), decimal128(6, 1),
decimal128(14, 8));
}
}

TEST(Expression, BindWithImplicitCasts) {
for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) {
// cast arguments to common numeric type
Expand Down
24 changes: 0 additions & 24 deletions cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,30 +519,6 @@ std::shared_ptr<MatchConstraint> DecimalsHaveSameScale() {
return instance;
}

namespace {

template <typename Op>
class BinaryDecimalScaleComparisonConstraint : public MatchConstraint {
public:
bool Matches(const std::vector<TypeHolder>& types) const override {
DCHECK_EQ(types.size(), 2);
DCHECK(is_decimal(types[0].id()));
DCHECK(is_decimal(types[1].id()));
const auto& ty0 = checked_cast<const DecimalType&>(*types[0].type);
const auto& ty1 = checked_cast<const DecimalType&>(*types[1].type);
return Op{}(ty0.scale(), ty1.scale());
}
};

} // namespace

std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2() {
using BinaryDecimalScale1GeScale2Constraint =
BinaryDecimalScaleComparisonConstraint<std::greater_equal<>>;
static auto instance = std::make_shared<BinaryDecimalScale1GeScale2Constraint>();
return instance;
}

// ----------------------------------------------------------------------
// KernelSignature

Expand Down
4 changes: 0 additions & 4 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,6 @@ ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
/// \brief Constraint that all input types are decimal types and have the same scale.
ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();

/// \brief Constraint that all binary input types are decimal types and the first type's
/// scale >= the second type's.
ARROW_EXPORT std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2();

/// \brief Holds the input types, optional match constraint and output type of the kernel.
///
/// VarArgs functions with minimum N arguments should pass up to N input types to be
Expand Down
66 changes: 24 additions & 42 deletions cpp/src/arrow/compute/kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,23 +341,6 @@ TEST(MatchConstraint, DecimalsHaveSameScale) {
decimal128(precision, scale + 1)}));
}

TEST(MatchConstraint, BinaryDecimalScaleComparisonGE) {
auto c = BinaryDecimalScale1GeScale2();
constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
ASSERT_TRUE(
c->Matches({decimal128(precision, big_scale), decimal128(precision, small_scale)}));
ASSERT_TRUE(
c->Matches({decimal128(precision, big_scale), decimal256(precision, small_scale)}));
ASSERT_TRUE(
c->Matches({decimal256(precision, big_scale), decimal128(precision, small_scale)}));
ASSERT_TRUE(
c->Matches({decimal256(precision, big_scale), decimal256(precision, small_scale)}));
ASSERT_TRUE(c->Matches(
{decimal128(precision, small_scale), decimal128(precision, small_scale)}));
ASSERT_FALSE(
c->Matches({decimal128(precision, small_scale), decimal128(precision, big_scale)}));
}

// ----------------------------------------------------------------------
// KernelSignature

Expand Down Expand Up @@ -471,31 +454,30 @@ TEST(KernelSignature, VarArgsMatchesInputs) {
}

TEST(KernelSignature, MatchesInputsWithConstraint) {
constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;

auto small_scale_decimal = decimal128(precision, small_scale);
auto big_scale_decimal = decimal128(precision, big_scale);

// No constraint.
KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128}, boolean());
ASSERT_TRUE(
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));
ASSERT_TRUE(
sig_no_constraint.MatchesInputs({small_scale_decimal, small_scale_decimal}));
ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, big_scale_decimal}));

for (auto constraint : {DecimalsHaveSameScale(), BinaryDecimalScale1GeScale2()}) {
KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(),
/*is_varargs=*/false, constraint);
ASSERT_EQ(constraint->Matches({small_scale_decimal, small_scale_decimal}),
sig.MatchesInputs({small_scale_decimal, small_scale_decimal}));
ASSERT_EQ(constraint->Matches({small_scale_decimal, big_scale_decimal}),
sig.MatchesInputs({small_scale_decimal, big_scale_decimal}));
ASSERT_EQ(constraint->Matches({big_scale_decimal, small_scale_decimal}),
sig.MatchesInputs({big_scale_decimal, small_scale_decimal}));
ASSERT_EQ(constraint->Matches({big_scale_decimal, big_scale_decimal}),
sig.MatchesInputs({big_scale_decimal, big_scale_decimal}));
auto precisions = {12, 22}, scales = {2, 3};
for (auto p1 : precisions) {
for (auto s1 : scales) {
auto d1 = decimal128(p1, s1);
for (auto p2 : precisions) {
for (auto s2 : scales) {
auto d2 = decimal128(p2, s2);

{
// No constraint.
KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128},
boolean());
ASSERT_TRUE(sig_no_constraint.MatchesInputs({d1, d2}));
}

{
// All decimal types must have the same scale.
KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(),
/*is_varargs=*/false, DecimalsHaveSameScale());
ASSERT_EQ(sig.MatchesInputs({d1, d2}), s1 == s2);
}
}
}
}
}
}

Expand Down
12 changes: 11 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,6 @@ void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) {
out_type = OutputType(ResolveDecimalMultiplicationOutput);
} else if (op == "divide") {
out_type = OutputType(ResolveDecimalDivisionOutput);
constraint = BinaryDecimalScale1GeScale2();
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really need this constraint to suppress the exact matching as this is now done via overridden DispatchExact.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not a DecimalsHaveSameScaleAndPrecision? (or full type equality, which is exactly equivalent here)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the tricky part about decimal division - there is no "exact match" at all.

By definition we ALWAYS promote the dividend no matter their (p, s) are. For example, decimal(5, 1) / decimal(5, 1) = decimal(11, 6).

As long as we allow any exact match, the promotion won't happen.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, {decimal(5, 1), decimal(5, 1)} looks like an exact match in this example. The result type is unrelated to this.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or you mean the dividend gets promoted to decimal(11, 6)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or you mean the dividend gets promoted to decimal(11, 6)?

Exactly. Except that it is actually promoted to decimal(11, 7) but you get the idea.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, thanks. Perhaps the PR description can be clearer about this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this is how we obey the resulting type rule we claim - promoting the dividend.

That said, there is an alternative though - as you implied in your previous comment

looks like an exact match in this example. The result type is unrelated to this.

This is also explained in my PR description approach 2. I didn't take that approach because that would require the promotion to happen during the underlying division for each individual value in the array. Can be cumbersome in terms of both coding and performance.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thanks for the explanation!

} else {
DCHECK(false);
}
Expand Down Expand Up @@ -727,6 +726,17 @@ ArrayKernelExec GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id)
struct ArithmeticFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;

Result<const Kernel*> DispatchExact(
const std::vector<TypeHolder>& types) const override {
if ((name_ == "divide" || name_ == "divide_checked") && HasDecimal(types)) {
// Decimal division ALWAYS scales up the dividend, so there will NEVER be an exact
// match.
return arrow::compute::detail::NoMatchingKernel(this, types);
}

return ScalarFunction::DispatchExact(types);
}

Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* types) const override {
RETURN_NOT_OK(CheckArity(types->size()));

Expand Down
68 changes: 44 additions & 24 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1942,14 +1942,14 @@ TEST_F(TestBinaryArithmeticDecimal, DispatchExact) {
name += suffix;
ARROW_SCOPED_TRACE(name);

CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)});
CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)});
CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 0)});
CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 1)});
CheckDispatchExactFails(name, {decimal128(3, 1), decimal128(2, 1)});
CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 0)});
CheckDispatchExactFails(name, {decimal128(2, 0), decimal128(2, 1)});

CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)});
CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)});
CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 0)});
CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 1)});
CheckDispatchExactFails(name, {decimal256(3, 1), decimal256(2, 1)});
CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 0)});
CheckDispatchExactFails(name, {decimal256(2, 0), decimal256(2, 1)});
}
}
Expand Down Expand Up @@ -2025,24 +2025,36 @@ TEST_F(TestBinaryArithmeticDecimal, DispatchBest) {
name += suffix;
SCOPED_TRACE(name);

CheckDispatchBest(name, {int64(), decimal128(1, 0)},
{decimal128(23, 4), decimal128(1, 0)});
CheckDispatchBest(name, {decimal128(1, 0), int64()},
{decimal128(21, 20), decimal128(19, 0)});

CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)},
{decimal128(6, 5), decimal128(2, 1)});
CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)},
{decimal256(6, 5), decimal256(2, 1)});
CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)},
{decimal256(6, 5), decimal256(2, 1)});
CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)},
{decimal256(6, 5), decimal256(2, 1)});

CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)},
{decimal128(7, 5), decimal128(2, 1)});
CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)},
{decimal128(5, 4), decimal128(2, 0)});
CheckDispatchBestWithCastedTypes(name, {int64(), decimal128(1, 0)},
{decimal128(23, 4), decimal128(1, 0)});
CheckDispatchBestWithCastedTypes(name, {decimal128(1, 0), int64()},
{decimal128(21, 20), decimal128(19, 0)});

CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal128(2, 1)},
{decimal128(6, 5), decimal128(2, 1)});
CheckDispatchBestWithCastedTypes(name, {decimal256(2, 1), decimal256(2, 1)},
{decimal256(6, 5), decimal256(2, 1)});
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal256(2, 1)},
{decimal256(6, 5), decimal256(2, 1)});
CheckDispatchBestWithCastedTypes(name, {decimal256(2, 1), decimal128(2, 1)},
{decimal256(6, 5), decimal256(2, 1)});

CheckDispatchBestWithCastedTypes(name, {decimal128(2, 0), decimal128(2, 1)},
{decimal128(7, 5), decimal128(2, 1)});
CheckDispatchBestWithCastedTypes(name, {decimal128(2, 1), decimal128(2, 0)},
{decimal128(5, 4), decimal128(2, 0)});

// GH-39875: Expression call to decimal(3 ,2) / decimal(15, 2) wrong result type.
// decimal128(3, 2) / decimal128(15, 2)
// -> decimal128(19, 18) / decimal128(15, 2) = decimal128(19, 16)
CheckDispatchBestWithCastedTypes(name, {decimal128(3, 2), decimal128(15, 2)},
{decimal128(19, 18), decimal128(15, 2)});

// GH-40911: Expression call to decimal(7 ,2) / decimal(6, 1) wrong result type.
// decimal128(7, 2) / decimal128(6, 1)
// -> decimal128(14, 9) / decimal128(6, 1) = decimal128(14, 8)
CheckDispatchBestWithCastedTypes(name, {decimal128(7, 2), decimal128(6, 1)},
{decimal128(14, 9), decimal128(6, 1)});
}
}
for (std::string name : {"atan2", "logb", "logb_checked", "power", "power_checked"}) {
Expand Down Expand Up @@ -2332,6 +2344,14 @@ TEST_F(TestBinaryArithmeticDecimal, Divide) {
CheckScalarBinary("divide", left, right, expected);
}

// decimal(p1, s1) decimal(p2, s2) where s1 < s2
{
auto left = ScalarFromJSON(decimal128(6, 5), R"("2.71828")");
auto right = ScalarFromJSON(decimal128(7, 6), R"("3.141592")");
auto expected = ScalarFromJSON(decimal128(14, 7), R"("0.8652555")");
CheckScalarBinary("divide", left, right, expected);
}

// decimal128 decimal256
{
auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")");
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/kernels/test_util_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,18 @@ void CheckDispatchBest(std::string func_name, std::vector<TypeHolder> original_v
}
}

void CheckDispatchBestWithCastedTypes(std::string func_name,
std::vector<TypeHolder> values,
const std::vector<TypeHolder>& expected_values) {
ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
ASSERT_OK_AND_ASSIGN(auto kernel, function->DispatchBest(&values));
ASSERT_NE(kernel, nullptr);
EXPECT_EQ(values.size(), expected_values.size());
for (size_t i = 0; i < values.size(); i++) {
AssertTypeEqual(*values[i], *expected_values[i]);
}
}

void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder> types) {
ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
ASSERT_NOT_OK(function->DispatchExact(types));
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/kernels/test_util_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ void CheckDispatchExact(std::string func_name, std::vector<TypeHolder> types);
void CheckDispatchBest(std::string func_name, std::vector<TypeHolder> types,
std::vector<TypeHolder> exact_types);

// Check that DispatchBest on a given function yields a valid Kernel and casts the input
// types as expected
void CheckDispatchBestWithCastedTypes(std::string func_name,
std::vector<TypeHolder> types,
const std::vector<TypeHolder>& expected_types);

// Check that function fails to produce a Kernel via DispatchExact for the set of types
void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder> types);

Expand Down
Loading