Skip to content

Commit 0118b9e

Browse files
committed
feat: Add checked_multiply for Spark decimal arithmetic
Add a checked decimal multiply function, ``checked_multiply``, that raises an error on overflow instead of returning NULL. This is needed for Spark's ANSI mode, where arithmetic overflow must raise an error rather than silently produce NULL.
1 parent 1aee3b4 commit 0118b9e

File tree

3 files changed

+133
-0
lines changed

3 files changed

+133
-0
lines changed

velox/docs/functions/spark/decimal.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,23 @@ Arithmetic Functions
143143
Division by zero or overflow results in an error.
144144
Corresponds to Spark's operator ``div`` with ``spark.sql.ansi.enabled`` set to true.
145145

146+
.. spark:function:: multiply(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3)
147+
148+
Returns the result of multiplying ``x`` and ``y``. The result type is determined
149+
by the precision and scale computation rules described above.
150+
Returns NULL when the result overflows.
151+
Corresponds to Spark's operator ``*`` with ``spark.sql.ansi.enabled`` set to false. ::
152+
153+
SELECT CAST(1.1 as DECIMAL(3, 1)) * CAST(2.0 as DECIMAL(3, 1)); -- 2.20
154+
SELECT CAST('99999999999999999999999999999999999999' as DECIMAL(38, 0)) * CAST(10 as DECIMAL(38, 0)); -- NULL
155+
156+
.. spark:function:: checked_multiply(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3)
157+
158+
Returns the result of multiplying ``x`` and ``y``. The result type is determined
159+
by the precision and scale computation rules described above.
160+
Throws an error when the result overflows.
161+
Corresponds to Spark's operator ``*`` with ``spark.sql.ansi.enabled`` set to true.
162+
146163
Decimal Functions
147164
-----------------
148165
.. spark:function:: ceil(x: decimal(p, s)) -> r: decimal(pr, 0)

velox/functions/sparksql/DecimalArithmetic.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,21 @@ struct DecimalMultiplyFunction {
442442
int32_t deltaScale_;
443443
};
444444

445+
// Decimal multiply function that reports overflow as an error Status instead
// of a plain "invalid" flag. Inherits from DecimalMultiplyFunction and wraps
// its call() so the multiplication logic itself is not duplicated.
446+
template <typename TExec, bool allowPrecisionLoss>
447+
struct CheckedDecimalMultiplyFunction
448+
: DecimalMultiplyFunction<TExec, allowPrecisionLoss> {
449+
VELOX_DEFINE_FUNCTION_TYPES(TExec);
450+
451+
// Multiplies 'a' by 'b' via the base class call() and writes the result to
// 'out'. The base call() yields a bool; when it is false this returns early
// with the message "Decimal overflow in multiply", otherwise Status::OK().
// NOTE(review): VELOX_USER_RETURN is presumably the macro that returns a
// user-error Status when its condition holds -- confirm against Status.h.
template <typename R, typename A, typename B>
452+
Status call(R& out, const A& a, const B& b) {
453+
// Explicit base qualification plus 'template' keyword: call() is a member
// template of a dependent base class.
bool valid = DecimalMultiplyFunction<TExec, allowPrecisionLoss>::
454+
template call<R, A, B>(out, a, b);
455+
VELOX_USER_RETURN(!valid, "Decimal overflow in multiply");
456+
return Status::OK();
457+
}
458+
};
459+
445460
template <typename TExec, bool allowPrecisionLoss>
446461
struct DecimalDivideFunction {
447462
VELOX_DEFINE_FUNCTION_TYPES(TExec);
@@ -686,6 +701,14 @@ using DivideFunctionAllowPrecisionLoss = DecimalDivideFunction<TExec, true>;
686701
template <typename TExec>
687702
using DivideFunctionDenyPrecisionLoss = DecimalDivideFunction<TExec, false>;
688703

704+
// CheckedDecimalMultiplyFunction with the allowPrecisionLoss template
// parameter fixed to true.
template <typename TExec>
705+
using CheckedMultiplyFunctionAllowPrecisionLoss =
706+
CheckedDecimalMultiplyFunction<TExec, true>;
707+
708+
// CheckedDecimalMultiplyFunction with the allowPrecisionLoss template
// parameter fixed to false.
template <typename TExec>
709+
using CheckedMultiplyFunctionDenyPrecisionLoss =
710+
CheckedDecimalMultiplyFunction<TExec, false>;
711+
689712
std::vector<exec::SignatureVariable> getDivideConstraintsDenyPrecisionLoss() {
690713
std::string wholeDigits = fmt::format(
691714
"min(38, {a_precision} - {a_scale} + {b_scale})",
@@ -806,6 +829,11 @@ void registerDecimalMultiply(const std::string& prefix) {
806829
registerDecimalBinary<MultiplyFunctionDenyPrecisionLoss>(
807830
prefix + "multiply" + kDenyPrecisionLoss,
808831
makeConstraints(rPrecision, rScale, false));
832+
registerDecimalBinary<CheckedMultiplyFunctionAllowPrecisionLoss>(
833+
prefix + "checked_multiply", makeConstraints(rPrecision, rScale, true));
834+
registerDecimalBinary<CheckedMultiplyFunctionDenyPrecisionLoss>(
835+
prefix + "checked_multiply" + kDenyPrecisionLoss,
836+
makeConstraints(rPrecision, rScale, false));
809837
}
810838

811839
void registerDecimalDivide(const std::string& prefix) {

velox/functions/sparksql/tests/DecimalArithmeticTest.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,16 @@ class DecimalArithmeticTest : public SparkFunctionBaseTest {
8888
std::optional<U> u) {
8989
return evaluateOnce<int64_t>("checked_div(c0, c1)", {tType, uType}, t, u);
9090
}
91+
92+
// Test helper: evaluates the expression "checked_multiply(c0, c1)" through
// evaluateOnce, passing 't' and 'u' as the inputs with types 'tType' and
// 'uType'. The result is requested as int128_t.
template <typename T, typename U>
93+
std::optional<int128_t> checked_multiply(
94+
const TypePtr& tType,
95+
const TypePtr& uType,
96+
std::optional<T> t,
97+
std::optional<U> u) {
98+
return evaluateOnce<int128_t>(
99+
"checked_multiply(c0, c1)", {tType, uType}, t, u);
100+
}
91101
};
92102

93103
TEST_F(DecimalArithmeticTest, add) {
@@ -833,5 +843,83 @@ TEST_F(DecimalArithmeticTest, checkedDiv) {
833843
1)),
834844
"Overflow in integral divide");
835845
}
846+
847+
// Verifies checked_multiply on regular inputs (all int64_t/int128_t argument
// combinations, zero, negative values, a near-boundary product) and checks
// that overflowing products raise a user error with the message
// "Decimal overflow in multiply".
TEST_F(DecimalArithmeticTest, checkedMultiply) {
848+
// Normal cases, one per int64_t/int128_t combination of the two arguments.
// First case: DECIMAL(17,3) * DECIMAL(17,3) -> result precision 35 (long).
849+
// 1.000 * 2.000 = 2.000000 (unscaled: 1000 * 2000 = 2000000).
850+
EXPECT_EQ(
851+
(checked_multiply<int64_t, int64_t>(
852+
DECIMAL(17, 3), DECIMAL(17, 3), 1000, 2000)),
853+
2000000);
854+
EXPECT_EQ(
855+
(checked_multiply<int64_t, int128_t>(
856+
DECIMAL(17, 3), DECIMAL(20, 3), 1000, 2000)),
857+
2000000);
858+
EXPECT_EQ(
859+
(checked_multiply<int128_t, int64_t>(
860+
DECIMAL(20, 3), DECIMAL(17, 3), 1000, 2000)),
861+
2000000);
862+
EXPECT_EQ(
863+
(checked_multiply<int128_t, int128_t>(
864+
DECIMAL(20, 3), DECIMAL(20, 3), 1000, 2000)),
865+
2000000);
866+
867+
// Multiplying by zero: 0.000 * 2.000 = 0.
868+
EXPECT_EQ(
869+
(checked_multiply<int64_t, int64_t>(
870+
DECIMAL(17, 3), DECIMAL(17, 3), 0, 2000)),
871+
0);
872+
873+
// Multiplying negative numbers: (-1.000) * 2.000 = -2.000000.
874+
EXPECT_EQ(
875+
(checked_multiply<int64_t, int64_t>(
876+
DECIMAL(17, 3), DECIMAL(17, 3), -1000, 2000)),
877+
-2000000);
878+
879+
// DECIMAL(38,0) * DECIMAL(38,0) -> result precision capped at 38, scale 0.
880+
// Small values stay well inside that range: 100 * 200 = 20000, no overflow.
881+
EXPECT_EQ(
882+
(checked_multiply<int128_t, int128_t>(
883+
DECIMAL(38, 0), DECIMAL(38, 0), 100, 200)),
884+
20000);
885+
886+
// Near-boundary success: large values that just fit.
887+
// 1e18 * 1e19 = 1e37, which fits in DECIMAL(38,0).
888+
EXPECT_EQ(
889+
(checked_multiply<int128_t, int128_t>(
890+
DECIMAL(38, 0),
891+
DECIMAL(38, 0),
892+
HugeInt::parse("1000000000000000000"),
893+
HugeInt::parse("10000000000000000000"))),
894+
HugeInt::parse("10000000000000000000000000000000000000"));
895+
896+
// Positive overflow should throw.
897+
// 1e19 * 1e19 = 1e38, which exceeds max DECIMAL(38,0).
898+
VELOX_ASSERT_USER_THROW(
899+
(checked_multiply<int128_t, int128_t>(
900+
DECIMAL(38, 0),
901+
DECIMAL(38, 0),
902+
HugeInt::parse("10000000000000000000"),
903+
HugeInt::parse("10000000000000000000"))),
904+
"Decimal overflow in multiply");
905+
906+
// Negative overflow should throw (positive * negative -> overflow).
// 1e19 * -1e19 = -1e38, which is below min DECIMAL(38,0).
907+
VELOX_ASSERT_USER_THROW(
908+
(checked_multiply<int128_t, int128_t>(
909+
DECIMAL(38, 0),
910+
DECIMAL(38, 0),
911+
HugeInt::parse("10000000000000000000"),
912+
HugeInt::parse("-10000000000000000000"))),
913+
"Decimal overflow in multiply");
914+
915+
// Negative * negative overflow should throw (result is positive but too large).
// -1e19 * -1e19 = 1e38, which exceeds max DECIMAL(38,0).
916+
VELOX_ASSERT_USER_THROW(
917+
(checked_multiply<int128_t, int128_t>(
918+
DECIMAL(38, 0),
919+
DECIMAL(38, 0),
920+
HugeInt::parse("-10000000000000000000"),
921+
HugeInt::parse("-10000000000000000000"))),
922+
"Decimal overflow in multiply");
923+
}
836924
} // namespace
837925
} // namespace facebook::velox::functions::sparksql::test

0 commit comments

Comments
 (0)