Skip to content

Commit 8805bc2

Browse files
committed
feat: Add checked_add and checked_subtract for Spark decimal arithmetic
Add checked decimal add and subtract functions that throw on overflow instead of returning null. These are needed for Spark's ANSI mode where arithmetic overflow should raise an error rather than silently produce null.
1 parent 1aee3b4 commit 8805bc2

File tree

3 files changed

+288
-0
lines changed

3 files changed

+288
-0
lines changed

velox/docs/functions/spark/decimal.rst

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,40 @@ Arithmetic Functions
143143
Division by zero or overflow results in an error.
144144
Corresponds to Spark's operator ``div`` with ``spark.sql.ansi.enabled`` set to true.
145145

146+
.. spark:function:: add(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3)
147+
148+
Returns the result of adding ``x`` and ``y``. The result type is determined
149+
by the precision and scale computation rules described above.
150+
Returns NULL when the result overflows.
151+
Corresponds to Spark's operator ``+`` with ``spark.sql.ansi.enabled`` set to false. ::
152+
153+
SELECT CAST(1.1 as DECIMAL(3, 1)) + CAST(2.2 as DECIMAL(3, 1)); -- 3.3
154+
SELECT CAST('99999999999999999999999999999999999999' as DECIMAL(38, 0)) + CAST(1 as DECIMAL(38, 0)); -- NULL
155+
156+
.. spark:function:: checked_add(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3)
157+
158+
Returns the result of adding ``x`` and ``y``. The result type is determined
159+
by the precision and scale computation rules described above.
160+
Throws an error when the result overflows.
161+
Corresponds to Spark's operator ``+`` with ``spark.sql.ansi.enabled`` set to true.
162+
163+
.. spark:function:: subtract(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3)
164+
165+
Returns the result of subtracting ``y`` from ``x``. The result type is determined
166+
by the precision and scale computation rules described above.
167+
Returns NULL when the result overflows.
168+
Corresponds to Spark's operator ``-`` with ``spark.sql.ansi.enabled`` set to false. ::
169+
170+
SELECT CAST(1.1 as DECIMAL(3, 1)) - CAST(2.2 as DECIMAL(3, 1)); -- -1.1
171+
SELECT CAST('-99999999999999999999999999999999999999' as DECIMAL(38, 0)) - CAST(1 as DECIMAL(38, 0)); -- NULL
172+
173+
.. spark:function:: checked_subtract(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p3, s3)
174+
175+
Returns the result of subtracting ``y`` from ``x``. The result type is determined
176+
by the precision and scale computation rules described above.
177+
Throws an error when the result overflows.
178+
Corresponds to Spark's operator ``-`` with ``spark.sql.ansi.enabled`` set to true.
179+
146180
Decimal Functions
147181
-----------------
148182
.. spark:function:: ceil(x: decimal(p, s)) -> r: decimal(pr, 0)

velox/functions/sparksql/DecimalArithmetic.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,50 @@ struct DecimalSubtractFunction : DecimalAddSubtractBase {
328328
}
329329
};
330330

331+
// Decimal add function that returns error on overflow.
332+
template <typename TExec, bool allowPrecisionLoss>
333+
struct CheckedDecimalAddFunction : DecimalAddSubtractBase {
334+
VELOX_DEFINE_FUNCTION_TYPES(TExec);
335+
336+
template <typename A, typename B>
337+
void initialize(
338+
const std::vector<TypePtr>& inputTypes,
339+
const core::QueryConfig& /*config*/,
340+
A* /*a*/,
341+
B* /*b*/) {
342+
initializeBase<allowPrecisionLoss>(inputTypes);
343+
}
344+
345+
template <typename R, typename A, typename B>
346+
Status call(R& out, const A& a, const B& b) {
347+
bool valid = applyAdd<R, A, B>(out, a, b);
348+
VELOX_USER_RETURN(!valid, "Decimal overflow in add");
349+
return Status::OK();
350+
}
351+
};
352+
353+
// Decimal subtract function that returns error on overflow.
354+
template <typename TExec, bool allowPrecisionLoss>
355+
struct CheckedDecimalSubtractFunction : DecimalAddSubtractBase {
356+
VELOX_DEFINE_FUNCTION_TYPES(TExec);
357+
358+
template <typename A, typename B>
359+
void initialize(
360+
const std::vector<TypePtr>& inputTypes,
361+
const core::QueryConfig& /*config*/,
362+
A* /*a*/,
363+
B* /*b*/) {
364+
initializeBase<allowPrecisionLoss>(inputTypes);
365+
}
366+
367+
template <typename R, typename A, typename B>
368+
Status call(R& out, const A& a, const B& b) {
369+
bool valid = applyAdd<R, A, B>(out, a, B(-b));
370+
VELOX_USER_RETURN(!valid, "Decimal overflow in subtract");
371+
return Status::OK();
372+
}
373+
};
374+
331375
template <typename TExec, bool allowPrecisionLoss>
332376
struct DecimalMultiplyFunction {
333377
VELOX_DEFINE_FUNCTION_TYPES(TExec);
@@ -686,6 +730,22 @@ using DivideFunctionAllowPrecisionLoss = DecimalDivideFunction<TExec, true>;
686730
template <typename TExec>
687731
using DivideFunctionDenyPrecisionLoss = DecimalDivideFunction<TExec, false>;
688732

733+
template <typename TExec>
734+
using CheckedAddFunctionAllowPrecisionLoss =
735+
CheckedDecimalAddFunction<TExec, true>;
736+
737+
template <typename TExec>
738+
using CheckedAddFunctionDenyPrecisionLoss =
739+
CheckedDecimalAddFunction<TExec, false>;
740+
741+
template <typename TExec>
742+
using CheckedSubtractFunctionAllowPrecisionLoss =
743+
CheckedDecimalSubtractFunction<TExec, true>;
744+
745+
template <typename TExec>
746+
using CheckedSubtractFunctionDenyPrecisionLoss =
747+
CheckedDecimalSubtractFunction<TExec, false>;
748+
689749
std::vector<exec::SignatureVariable> getDivideConstraintsDenyPrecisionLoss() {
690750
std::string wholeDigits = fmt::format(
691751
"min(38, {a_precision} - {a_scale} + {b_scale})",
@@ -781,6 +841,11 @@ void registerDecimalAdd(const std::string& prefix) {
781841
registerDecimalBinary<AddFunctionDenyPrecisionLoss>(
782842
prefix + "add" + kDenyPrecisionLoss,
783843
makeConstraints(rPrecision, rScale, false));
844+
registerDecimalBinary<CheckedAddFunctionAllowPrecisionLoss>(
845+
prefix + "checked_add", makeConstraints(rPrecision, rScale, true));
846+
registerDecimalBinary<CheckedAddFunctionDenyPrecisionLoss>(
847+
prefix + "checked_add" + kDenyPrecisionLoss,
848+
makeConstraints(rPrecision, rScale, false));
784849
}
785850

786851
void registerDecimalSubtract(const std::string& prefix) {
@@ -790,6 +855,11 @@ void registerDecimalSubtract(const std::string& prefix) {
790855
registerDecimalBinary<SubtractFunctionDenyPrecisionLoss>(
791856
prefix + "subtract" + kDenyPrecisionLoss,
792857
makeConstraints(rPrecision, rScale, false));
858+
registerDecimalBinary<CheckedSubtractFunctionAllowPrecisionLoss>(
859+
prefix + "checked_subtract", makeConstraints(rPrecision, rScale, true));
860+
registerDecimalBinary<CheckedSubtractFunctionDenyPrecisionLoss>(
861+
prefix + "checked_subtract" + kDenyPrecisionLoss,
862+
makeConstraints(rPrecision, rScale, false));
793863
}
794864

795865
void registerDecimalMultiply(const std::string& prefix) {

velox/functions/sparksql/tests/DecimalArithmeticTest.cpp

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,25 @@ class DecimalArithmeticTest : public SparkFunctionBaseTest {
8888
std::optional<U> u) {
8989
return evaluateOnce<int64_t>("checked_div(c0, c1)", {tType, uType}, t, u);
9090
}
91+
92+
template <typename T, typename U>
93+
std::optional<int128_t> checked_add(
94+
const TypePtr& tType,
95+
const TypePtr& uType,
96+
std::optional<T> t,
97+
std::optional<U> u) {
98+
return evaluateOnce<int128_t>("checked_add(c0, c1)", {tType, uType}, t, u);
99+
}
100+
101+
template <typename T, typename U>
102+
std::optional<int128_t> checked_subtract(
103+
const TypePtr& tType,
104+
const TypePtr& uType,
105+
std::optional<T> t,
106+
std::optional<U> u) {
107+
return evaluateOnce<int128_t>(
108+
"checked_subtract(c0, c1)", {tType, uType}, t, u);
109+
}
91110
};
92111

93112
TEST_F(DecimalArithmeticTest, add) {
@@ -833,5 +852,170 @@ TEST_F(DecimalArithmeticTest, checkedDiv) {
833852
1)),
834853
"Overflow in integral divide");
835854
}
855+
856+
TEST_F(DecimalArithmeticTest, checkedAdd) {
857+
// Normal cases should work.
858+
// Use DECIMAL(18, 2) so result precision (19) exceeds 18 (long decimal).
859+
EXPECT_EQ(
860+
(checked_add<int64_t, int64_t>(DECIMAL(18, 2), DECIMAL(18, 2), 100, 200)),
861+
300);
862+
EXPECT_EQ(
863+
(checked_add<int64_t, int128_t>(
864+
DECIMAL(18, 2), DECIMAL(20, 2), 100, 200)),
865+
300);
866+
EXPECT_EQ(
867+
(checked_add<int128_t, int64_t>(
868+
DECIMAL(20, 2), DECIMAL(18, 2), 100, 200)),
869+
300);
870+
EXPECT_EQ(
871+
(checked_add<int128_t, int128_t>(
872+
DECIMAL(20, 2), DECIMAL(20, 2), 100, 200)),
873+
300);
874+
875+
// Adding with zero.
876+
EXPECT_EQ(
877+
(checked_add<int64_t, int64_t>(DECIMAL(18, 2), DECIMAL(18, 2), 0, 100)),
878+
100);
879+
880+
// Adding negative numbers.
881+
EXPECT_EQ(
882+
(checked_add<int64_t, int64_t>(DECIMAL(18, 2), DECIMAL(18, 2), -100, 50)),
883+
-50);
884+
885+
// Result precision capped at 38, no overflow.
886+
EXPECT_EQ(
887+
(checked_add<int128_t, int128_t>(
888+
DECIMAL(38, 0), DECIMAL(38, 0), 100, 200)),
889+
300);
890+
891+
// Near-boundary success: large values through addLarge path, but fits.
892+
EXPECT_EQ(
893+
(checked_add<int128_t, int128_t>(
894+
DECIMAL(38, 0),
895+
DECIMAL(38, 0),
896+
HugeInt::parse("49999999999999999999999999999999999999"),
897+
HugeInt::parse("49999999999999999999999999999999999999"))),
898+
HugeInt::parse("99999999999999999999999999999999999998"));
899+
900+
// Positive overflow should throw.
901+
VELOX_ASSERT_USER_THROW(
902+
(checked_add<int128_t, int128_t>(
903+
DECIMAL(38, 0),
904+
DECIMAL(38, 0),
905+
HugeInt::parse("99999999999999999999999999999999999999"),
906+
HugeInt::parse("99999999999999999999999999999999999999"))),
907+
"Decimal overflow in add");
908+
909+
// Positive overflow with large positive and small positive.
910+
VELOX_ASSERT_USER_THROW(
911+
(checked_add<int128_t, int128_t>(
912+
DECIMAL(38, 0),
913+
DECIMAL(38, 0),
914+
HugeInt::parse("99999999999999999999999999999999999999"),
915+
1)),
916+
"Decimal overflow in add");
917+
918+
// Negative overflow should throw.
919+
VELOX_ASSERT_USER_THROW(
920+
(checked_add<int128_t, int128_t>(
921+
DECIMAL(38, 0),
922+
DECIMAL(38, 0),
923+
HugeInt::parse("-99999999999999999999999999999999999999"),
924+
HugeInt::parse("-99999999999999999999999999999999999999"))),
925+
"Decimal overflow in add");
926+
927+
// Negative overflow with large negative and small negative.
928+
VELOX_ASSERT_USER_THROW(
929+
(checked_add<int128_t, int128_t>(
930+
DECIMAL(38, 0),
931+
DECIMAL(38, 0),
932+
HugeInt::parse("-99999999999999999999999999999999999999"),
933+
-1)),
934+
"Decimal overflow in add");
935+
}
936+
937+
TEST_F(DecimalArithmeticTest, checkedSubtract) {
938+
// Normal cases should work.
939+
// Use DECIMAL(18, 2) so result precision (19) exceeds 18 (long decimal).
940+
EXPECT_EQ(
941+
(checked_subtract<int64_t, int64_t>(
942+
DECIMAL(18, 2), DECIMAL(18, 2), 300, 200)),
943+
100);
944+
EXPECT_EQ(
945+
(checked_subtract<int64_t, int128_t>(
946+
DECIMAL(18, 2), DECIMAL(20, 2), 300, 200)),
947+
100);
948+
EXPECT_EQ(
949+
(checked_subtract<int128_t, int64_t>(
950+
DECIMAL(20, 2), DECIMAL(18, 2), 300, 200)),
951+
100);
952+
EXPECT_EQ(
953+
(checked_subtract<int128_t, int128_t>(
954+
DECIMAL(20, 2), DECIMAL(20, 2), 300, 200)),
955+
100);
956+
957+
// Subtracting zero.
958+
EXPECT_EQ(
959+
(checked_subtract<int64_t, int64_t>(
960+
DECIMAL(18, 2), DECIMAL(18, 2), 100, 0)),
961+
100);
962+
963+
// Subtracting negative (effectively adding).
964+
EXPECT_EQ(
965+
(checked_subtract<int64_t, int64_t>(
966+
DECIMAL(18, 2), DECIMAL(18, 2), 100, -50)),
967+
150);
968+
969+
// Result precision capped at 38, no overflow.
970+
EXPECT_EQ(
971+
(checked_subtract<int128_t, int128_t>(
972+
DECIMAL(38, 0), DECIMAL(38, 0), 300, 200)),
973+
100);
974+
975+
// Near-boundary success: large values through addLarge path, but fits.
976+
EXPECT_EQ(
977+
(checked_subtract<int128_t, int128_t>(
978+
DECIMAL(38, 0),
979+
DECIMAL(38, 0),
980+
HugeInt::parse("49999999999999999999999999999999999999"),
981+
HugeInt::parse("-49999999999999999999999999999999999999"))),
982+
HugeInt::parse("99999999999999999999999999999999999998"));
983+
984+
// Negative overflow should throw.
985+
VELOX_ASSERT_USER_THROW(
986+
(checked_subtract<int128_t, int128_t>(
987+
DECIMAL(38, 0),
988+
DECIMAL(38, 0),
989+
HugeInt::parse("-99999999999999999999999999999999999999"),
990+
HugeInt::parse("99999999999999999999999999999999999999"))),
991+
"Decimal overflow in subtract");
992+
993+
// Negative overflow with large negative and small positive.
994+
VELOX_ASSERT_USER_THROW(
995+
(checked_subtract<int128_t, int128_t>(
996+
DECIMAL(38, 0),
997+
DECIMAL(38, 0),
998+
HugeInt::parse("-99999999999999999999999999999999999999"),
999+
1)),
1000+
"Decimal overflow in subtract");
1001+
1002+
// Positive overflow should throw.
1003+
VELOX_ASSERT_USER_THROW(
1004+
(checked_subtract<int128_t, int128_t>(
1005+
DECIMAL(38, 0),
1006+
DECIMAL(38, 0),
1007+
HugeInt::parse("99999999999999999999999999999999999999"),
1008+
HugeInt::parse("-99999999999999999999999999999999999999"))),
1009+
"Decimal overflow in subtract");
1010+
1011+
// Positive overflow with large positive and small negative.
1012+
VELOX_ASSERT_USER_THROW(
1013+
(checked_subtract<int128_t, int128_t>(
1014+
DECIMAL(38, 0),
1015+
DECIMAL(38, 0),
1016+
HugeInt::parse("99999999999999999999999999999999999999"),
1017+
-1)),
1018+
"Decimal overflow in subtract");
1019+
}
8361020
} // namespace
8371021
} // namespace facebook::velox::functions::sparksql::test

0 commit comments

Comments
 (0)