diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index a5f211b9bc64..206f6e08260d 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -349,6 +349,11 @@ Mathematical Functions Returns the negative of `x`. Corresponds to Spark's operator ``-``. + For integral types, when ``spark.ansi.enabled`` is true, throws an + arithmetic error if `x` is the minimum value of its type (e.g., -128 for + tinyint). When ``spark.ansi.enabled`` is false, returns the minimum value + unchanged. + .. spark:function:: unhex(x) -> varbinary Converts hexadecimal varchar ``x`` to varbinary. diff --git a/velox/functions/sparksql/Arithmetic.h b/velox/functions/sparksql/Arithmetic.h index 215af7ab9967..7057052dbac8 100644 --- a/velox/functions/sparksql/Arithmetic.h +++ b/velox/functions/sparksql/Arithmetic.h @@ -146,15 +146,34 @@ struct PModFloatFunction { template struct UnaryMinusFunction { template - FOLLY_ALWAYS_INLINE bool call(TInput& result, const TInput a) { + FOLLY_ALWAYS_INLINE void initialize( + const std::vector& /*inputTypes*/, + const core::QueryConfig& config, + const TInput* /*a*/) { + ansiEnabled_ = config.sparkAnsiEnabled(); + } + + template + FOLLY_ALWAYS_INLINE Status call(TInput& result, const TInput a) { if constexpr (std::is_integral_v) { - // Avoid undefined integer overflow. - result = a == std::numeric_limits::min() ? a : -a; - } else { - result = -a; + if (FOLLY_UNLIKELY(a == std::numeric_limits::min())) { + if (ansiEnabled_) { + if (threadSkipErrorDetails()) { + return Status::UserError(); + } + return Status::UserError("Arithmetic overflow: -({}).", a); + } + // In non-ANSI mode, returns the same negative minimum value. + result = a; + return Status::OK(); + } } - return true; + result = -a; + return Status::OK(); } + + private: + bool ansiEnabled_ = false; }; template diff --git a/velox/functions/sparksql/tests/ArithmeticTest.cpp b/velox/functions/sparksql/tests/ArithmeticTest.cpp index 7efb55ef0838..892f59180c01 100644 --- a/velox/functions/sparksql/tests/ArithmeticTest.cpp +++ b/velox/functions/sparksql/tests/ArithmeticTest.cpp @@ -310,14 +310,39 @@ TEST_F(ArithmeticTest, UnaryMinus) { } TEST_F(ArithmeticTest, UnaryMinusOverflow) { + // Float/double cases are unaffected by ANSI mode. + for (const auto& ansiEnabled : {"false", "true"}) { + queryCtx_->testingOverrideConfigUnsafe( + {{core::QueryConfig::kSparkAnsiEnabled, ansiEnabled}}); + + EXPECT_EQ(unaryminus(-kInf), kInf); + EXPECT_TRUE(std::isnan(unaryminus(kNan).value_or(0))); + EXPECT_EQ(unaryminus(-kInf), kInf); + EXPECT_TRUE(std::isnan(unaryminus(kNan).value_or(0))); + } + + // With ANSI off, negating MIN returns MIN (wraps silently). + queryCtx_->testingOverrideConfigUnsafe( + {{core::QueryConfig::kSparkAnsiEnabled, "false"}}); + EXPECT_EQ(unaryminus(INT8_MIN), INT8_MIN); EXPECT_EQ(unaryminus(INT16_MIN), INT16_MIN); EXPECT_EQ(unaryminus(INT32_MIN), INT32_MIN); EXPECT_EQ(unaryminus(INT64_MIN), INT64_MIN); - EXPECT_EQ(unaryminus(-kInf), kInf); - EXPECT_TRUE(std::isnan(unaryminus(kNan).value_or(0))); - EXPECT_EQ(unaryminus(-kInf), kInf); - EXPECT_TRUE(std::isnan(unaryminus(kNan).value_or(0))); + + // With ANSI on, negating MIN throws. + queryCtx_->testingOverrideConfigUnsafe( + {{core::QueryConfig::kSparkAnsiEnabled, "true"}}); + + VELOX_ASSERT_THROW(unaryminus(INT8_MIN), "Arithmetic overflow"); + VELOX_ASSERT_THROW(unaryminus(INT16_MIN), "Arithmetic overflow"); + VELOX_ASSERT_THROW(unaryminus(INT32_MIN), "Arithmetic overflow"); + VELOX_ASSERT_THROW(unaryminus(INT64_MIN), "Arithmetic overflow"); + + // TRY wrapping returns null instead of throwing. + auto tryResult = + evaluateOnce("try(unaryminus(c0))", std::optional(INT64_MIN)); + EXPECT_FALSE(tryResult.has_value()); } TEST_F(ArithmeticTest, Divide) {