Skip to content

Commit e3d80dc

Browse files
committed
feat: Add checked_add and checked_subtract for Spark decimal arithmetic
Add checked decimal add and subtract functions that throw on overflow instead of returning null. These are needed for Spark's ANSI mode where arithmetic overflow should raise an error rather than silently produce null.
1 parent 022acc1 commit e3d80dc

File tree

2 files changed

+253
-0
lines changed

2 files changed

+253
-0
lines changed

velox/functions/sparksql/DecimalArithmetic.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,50 @@ struct DecimalSubtractFunction : DecimalAddSubtractBase {
328328
}
329329
};
330330

331+
// Decimal add function that returns error on overflow.
332+
template <typename TExec, bool allowPrecisionLoss>
333+
struct CheckedDecimalAddFunction : DecimalAddSubtractBase {
334+
VELOX_DEFINE_FUNCTION_TYPES(TExec);
335+
336+
template <typename A, typename B>
337+
void initialize(
338+
const std::vector<TypePtr>& inputTypes,
339+
const core::QueryConfig& /*config*/,
340+
A* /*a*/,
341+
B* /*b*/) {
342+
initializeBase<allowPrecisionLoss>(inputTypes);
343+
}
344+
345+
template <typename R, typename A, typename B>
346+
Status call(R& out, const A& a, const B& b) {
347+
auto overflow = applyAdd<R, A, B>(out, a, b);
348+
VELOX_USER_RETURN(!overflow, "Decimal overflow in add");
349+
return Status::OK();
350+
}
351+
};
352+
353+
// Decimal subtract function that returns error on overflow.
354+
template <typename TExec, bool allowPrecisionLoss>
355+
struct CheckedDecimalSubtractFunction : DecimalAddSubtractBase {
356+
VELOX_DEFINE_FUNCTION_TYPES(TExec);
357+
358+
template <typename A, typename B>
359+
void initialize(
360+
const std::vector<TypePtr>& inputTypes,
361+
const core::QueryConfig& /*config*/,
362+
A* /*a*/,
363+
B* /*b*/) {
364+
initializeBase<allowPrecisionLoss>(inputTypes);
365+
}
366+
367+
template <typename R, typename A, typename B>
368+
Status call(R& out, const A& a, const B& b) {
369+
auto overflow = applyAdd<R, A, B>(out, a, B(-b));
370+
VELOX_USER_RETURN(!overflow, "Decimal overflow in subtract");
371+
return Status::OK();
372+
}
373+
};
374+
331375
template <typename TExec, bool allowPrecisionLoss>
332376
struct DecimalMultiplyFunction {
333377
VELOX_DEFINE_FUNCTION_TYPES(TExec);
@@ -686,6 +730,22 @@ using DivideFunctionAllowPrecisionLoss = DecimalDivideFunction<TExec, true>;
686730
template <typename TExec>
687731
using DivideFunctionDenyPrecisionLoss = DecimalDivideFunction<TExec, false>;
688732

733+
template <typename TExec>
734+
using CheckedAddFunctionAllowPrecisionLoss =
735+
CheckedDecimalAddFunction<TExec, true>;
736+
737+
template <typename TExec>
738+
using CheckedAddFunctionDenyPrecisionLoss =
739+
CheckedDecimalAddFunction<TExec, false>;
740+
741+
template <typename TExec>
742+
using CheckedSubtractFunctionAllowPrecisionLoss =
743+
CheckedDecimalSubtractFunction<TExec, true>;
744+
745+
template <typename TExec>
746+
using CheckedSubtractFunctionDenyPrecisionLoss =
747+
CheckedDecimalSubtractFunction<TExec, false>;
748+
689749
std::vector<exec::SignatureVariable> getDivideConstraintsDenyPrecisionLoss() {
690750
std::string wholeDigits = fmt::format(
691751
"min(38, {a_precision} - {a_scale} + {b_scale})",
@@ -781,6 +841,11 @@ void registerDecimalAdd(const std::string& prefix) {
781841
registerDecimalBinary<AddFunctionDenyPrecisionLoss>(
782842
prefix + "add" + kDenyPrecisionLoss,
783843
makeConstraints(rPrecision, rScale, false));
844+
registerDecimalBinary<CheckedAddFunctionAllowPrecisionLoss>(
845+
prefix + "checked_add", makeConstraints(rPrecision, rScale, true));
846+
registerDecimalBinary<CheckedAddFunctionDenyPrecisionLoss>(
847+
prefix + "checked_add" + kDenyPrecisionLoss,
848+
makeConstraints(rPrecision, rScale, false));
784849
}
785850

786851
void registerDecimalSubtract(const std::string& prefix) {
@@ -790,6 +855,11 @@ void registerDecimalSubtract(const std::string& prefix) {
790855
registerDecimalBinary<SubtractFunctionDenyPrecisionLoss>(
791856
prefix + "subtract" + kDenyPrecisionLoss,
792857
makeConstraints(rPrecision, rScale, false));
858+
registerDecimalBinary<CheckedSubtractFunctionAllowPrecisionLoss>(
859+
prefix + "checked_subtract", makeConstraints(rPrecision, rScale, true));
860+
registerDecimalBinary<CheckedSubtractFunctionDenyPrecisionLoss>(
861+
prefix + "checked_subtract" + kDenyPrecisionLoss,
862+
makeConstraints(rPrecision, rScale, false));
793863
}
794864

795865
void registerDecimalMultiply(const std::string& prefix) {
@@ -821,4 +891,5 @@ void registerDecimalIntegralDivide(const std::string& prefix) {
821891
registerIntegralDecimalDivide<CheckedDecimalIntegralDivideFunction>(
822892
prefix + "checked_div");
823893
}
894+
824895
} // namespace facebook::velox::functions::sparksql

velox/functions/sparksql/tests/DecimalArithmeticTest.cpp

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,25 @@ class DecimalArithmeticTest : public SparkFunctionBaseTest {
8888
std::optional<U> u) {
8989
return evaluateOnce<int64_t>("checked_div(c0, c1)", {tType, uType}, t, u);
9090
}
91+
92+
template <typename T, typename U>
93+
std::optional<int128_t> checked_add(
94+
const TypePtr& tType,
95+
const TypePtr& uType,
96+
std::optional<T> t,
97+
std::optional<U> u) {
98+
return evaluateOnce<int128_t>("checked_add(c0, c1)", {tType, uType}, t, u);
99+
}
100+
101+
template <typename T, typename U>
102+
std::optional<int128_t> checked_subtract(
103+
const TypePtr& tType,
104+
const TypePtr& uType,
105+
std::optional<T> t,
106+
std::optional<U> u) {
107+
return evaluateOnce<int128_t>(
108+
"checked_subtract(c0, c1)", {tType, uType}, t, u);
109+
}
91110
};
92111

93112
TEST_F(DecimalArithmeticTest, add) {
@@ -833,5 +852,168 @@ TEST_F(DecimalArithmeticTest, checkedDiv) {
833852
1)),
834853
"Overflow in integral divide");
835854
}
855+
856+
TEST_F(DecimalArithmeticTest, checkedAdd) {
857+
// Normal cases should work.
858+
// Use DECIMAL(18, 2) so result precision (19) exceeds 18 (long decimal).
859+
EXPECT_EQ(
860+
(checked_add<int64_t, int64_t>(DECIMAL(18, 2), DECIMAL(18, 2), 100, 200)),
861+
300);
862+
EXPECT_EQ(
863+
(checked_add<int64_t, int128_t>(DECIMAL(18, 2), DECIMAL(20, 2), 100, 200)),
864+
300);
865+
EXPECT_EQ(
866+
(checked_add<int128_t, int64_t>(DECIMAL(20, 2), DECIMAL(18, 2), 100, 200)),
867+
300);
868+
EXPECT_EQ(
869+
(checked_add<int128_t, int128_t>(
870+
DECIMAL(20, 2), DECIMAL(20, 2), 100, 200)),
871+
300);
872+
873+
// Adding with zero.
874+
EXPECT_EQ(
875+
(checked_add<int64_t, int64_t>(DECIMAL(18, 2), DECIMAL(18, 2), 0, 100)),
876+
100);
877+
878+
// Adding negative numbers.
879+
EXPECT_EQ(
880+
(checked_add<int64_t, int64_t>(DECIMAL(18, 2), DECIMAL(18, 2), -100, 50)),
881+
-50);
882+
883+
// Result precision capped at 38, no overflow.
884+
EXPECT_EQ(
885+
(checked_add<int128_t, int128_t>(
886+
DECIMAL(38, 0), DECIMAL(38, 0), 100, 200)),
887+
300);
888+
889+
// Near-boundary success: large values through addLarge path, but fits.
890+
EXPECT_EQ(
891+
(checked_add<int128_t, int128_t>(
892+
DECIMAL(38, 0),
893+
DECIMAL(38, 0),
894+
HugeInt::parse("49999999999999999999999999999999999999"),
895+
HugeInt::parse("49999999999999999999999999999999999999"))),
896+
HugeInt::parse("99999999999999999999999999999999999998"));
897+
898+
// Positive overflow should throw.
899+
VELOX_ASSERT_USER_THROW(
900+
(checked_add<int128_t, int128_t>(
901+
DECIMAL(38, 0),
902+
DECIMAL(38, 0),
903+
HugeInt::parse("99999999999999999999999999999999999999"),
904+
HugeInt::parse("99999999999999999999999999999999999999"))),
905+
"Decimal overflow in add");
906+
907+
// Positive overflow with large positive and small positive.
908+
VELOX_ASSERT_USER_THROW(
909+
(checked_add<int128_t, int128_t>(
910+
DECIMAL(38, 0),
911+
DECIMAL(38, 0),
912+
HugeInt::parse("99999999999999999999999999999999999999"),
913+
1)),
914+
"Decimal overflow in add");
915+
916+
// Negative overflow should throw.
917+
VELOX_ASSERT_USER_THROW(
918+
(checked_add<int128_t, int128_t>(
919+
DECIMAL(38, 0),
920+
DECIMAL(38, 0),
921+
HugeInt::parse("-99999999999999999999999999999999999999"),
922+
HugeInt::parse("-99999999999999999999999999999999999999"))),
923+
"Decimal overflow in add");
924+
925+
// Negative overflow with large negative and small negative.
926+
VELOX_ASSERT_USER_THROW(
927+
(checked_add<int128_t, int128_t>(
928+
DECIMAL(38, 0),
929+
DECIMAL(38, 0),
930+
HugeInt::parse("-99999999999999999999999999999999999999"),
931+
-1)),
932+
"Decimal overflow in add");
933+
}
934+
935+
TEST_F(DecimalArithmeticTest, checkedSubtract) {
936+
// Normal cases should work.
937+
// Use DECIMAL(18, 2) so result precision (19) exceeds 18 (long decimal).
938+
EXPECT_EQ(
939+
(checked_subtract<int64_t, int64_t>(
940+
DECIMAL(18, 2), DECIMAL(18, 2), 300, 200)),
941+
100);
942+
EXPECT_EQ(
943+
(checked_subtract<int64_t, int128_t>(
944+
DECIMAL(18, 2), DECIMAL(20, 2), 300, 200)),
945+
100);
946+
EXPECT_EQ(
947+
(checked_subtract<int128_t, int64_t>(
948+
DECIMAL(20, 2), DECIMAL(18, 2), 300, 200)),
949+
100);
950+
EXPECT_EQ(
951+
(checked_subtract<int128_t, int128_t>(
952+
DECIMAL(20, 2), DECIMAL(20, 2), 300, 200)),
953+
100);
954+
955+
// Subtracting zero.
956+
EXPECT_EQ(
957+
(checked_subtract<int64_t, int64_t>(
958+
DECIMAL(18, 2), DECIMAL(18, 2), 100, 0)),
959+
100);
960+
961+
// Subtracting negative (effectively adding).
962+
EXPECT_EQ(
963+
(checked_subtract<int64_t, int64_t>(
964+
DECIMAL(18, 2), DECIMAL(18, 2), 100, -50)),
965+
150);
966+
967+
// Result precision capped at 38, no overflow.
968+
EXPECT_EQ(
969+
(checked_subtract<int128_t, int128_t>(
970+
DECIMAL(38, 0), DECIMAL(38, 0), 300, 200)),
971+
100);
972+
973+
// Near-boundary success: large values through addLarge path, but fits.
974+
EXPECT_EQ(
975+
(checked_subtract<int128_t, int128_t>(
976+
DECIMAL(38, 0),
977+
DECIMAL(38, 0),
978+
HugeInt::parse("49999999999999999999999999999999999999"),
979+
HugeInt::parse("-49999999999999999999999999999999999999"))),
980+
HugeInt::parse("99999999999999999999999999999999999998"));
981+
982+
// Negative overflow should throw.
983+
VELOX_ASSERT_USER_THROW(
984+
(checked_subtract<int128_t, int128_t>(
985+
DECIMAL(38, 0),
986+
DECIMAL(38, 0),
987+
HugeInt::parse("-99999999999999999999999999999999999999"),
988+
HugeInt::parse("99999999999999999999999999999999999999"))),
989+
"Decimal overflow in subtract");
990+
991+
// Negative overflow with large negative and small positive.
992+
VELOX_ASSERT_USER_THROW(
993+
(checked_subtract<int128_t, int128_t>(
994+
DECIMAL(38, 0),
995+
DECIMAL(38, 0),
996+
HugeInt::parse("-99999999999999999999999999999999999999"),
997+
1)),
998+
"Decimal overflow in subtract");
999+
1000+
// Positive overflow should throw.
1001+
VELOX_ASSERT_USER_THROW(
1002+
(checked_subtract<int128_t, int128_t>(
1003+
DECIMAL(38, 0),
1004+
DECIMAL(38, 0),
1005+
HugeInt::parse("99999999999999999999999999999999999999"),
1006+
HugeInt::parse("-99999999999999999999999999999999999999"))),
1007+
"Decimal overflow in subtract");
1008+
1009+
// Positive overflow with large positive and small negative.
1010+
VELOX_ASSERT_USER_THROW(
1011+
(checked_subtract<int128_t, int128_t>(
1012+
DECIMAL(38, 0),
1013+
DECIMAL(38, 0),
1014+
HugeInt::parse("99999999999999999999999999999999999999"),
1015+
-1)),
1016+
"Decimal overflow in subtract");
1017+
}
8361018
} // namespace
8371019
} // namespace facebook::velox::functions::sparksql::test

0 commit comments

Comments
 (0)