diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index a33df705b3..d043baf919 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -63,7 +63,7 @@ use datafusion::{
 use datafusion_comet_spark_expr::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, create_modulo_expr,
     create_negate_expr, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode,
-    SparkHour, SparkMinute, SparkSecond,
+    SparkHour, SparkMinute, SparkSecond, SumInteger,
 };
 
 use crate::execution::operators::ExecutionError::GeneralError;
@@ -1833,6 +1833,12 @@ impl PhysicalPlanner {
                 let func = AggregateUDF::new_from_impl(SumDecimal::try_new(datatype)?);
                 AggregateExprBuilder::new(Arc::new(func), vec![child])
             }
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let func =
+                    AggregateUDF::new_from_impl(SumInteger::try_new(datatype, eval_mode)?);
+                AggregateExprBuilder::new(Arc::new(func), vec![child])
+            }
             _ => {
                 // cast to the result data type of SUM if necessary, we should not expect
                 // a cast failure since it should have already been checked at Spark side
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index c9037dcd69..a7736f561a 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -120,7 +120,7 @@ message Count {
 message Sum {
   Expr child = 1;
   DataType datatype = 2;
-  bool fail_on_error = 3;
+  EvalMode eval_mode = 3;
 }
 
 message Min {
diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs
index 252da78890..b1027153e8 100644
--- a/native/spark-expr/src/agg_funcs/mod.rs
+++ b/native/spark-expr/src/agg_funcs/mod.rs
@@ -21,6 +21,7 @@ mod correlation;
 mod covariance;
 mod stddev;
 mod sum_decimal;
+mod sum_int;
 mod variance;
 
 pub use avg::Avg;
@@ -29,4 +30,5 @@ pub use correlation::Correlation;
 pub use covariance::Covariance;
 pub use stddev::Stddev;
 pub use sum_decimal::SumDecimal;
+pub use sum_int::SumInteger;
 pub use variance::Variance;
diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs
new file mode 100644
index 0000000000..92156b629e
--- /dev/null
+++ b/native/spark-expr/src/agg_funcs/sum_int.rs
@@ -0,0 +1,573 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{arithmetic_overflow_error, EvalMode};
+use arrow::array::{
+    cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray,
+    Int64Array, PrimitiveArray,
+};
+use arrow::datatypes::{
+    ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type,
+};
+use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
+use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion::logical_expr::Volatility::Immutable;
+use datafusion::logical_expr::{
+    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
+};
+use std::{any::Any, sync::Arc};
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SumInteger {
+    signature: Signature,
+    eval_mode: EvalMode,
+}
+
+impl SumInteger {
+    pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
+        match data_type {
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self {
+                signature: Signature::user_defined(Immutable),
+                eval_mode,
+            }),
+            _ => Err(DataFusionError::Internal(
+                "Invalid data type for SumInteger".into(),
+            )),
+        }
+    }
+}
+
+impl AggregateUDFImpl for SumInteger {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "sum"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Int64)
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
+        Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode)))
+    }
+
+    fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
+        if self.eval_mode == EvalMode::Try {
+            Ok(vec![
+                Arc::new(Field::new("sum", DataType::Int64, true)),
+                Arc::new(Field::new("has_all_nulls", DataType::Boolean, false)),
+            ])
+        } else {
+            Ok(vec![Arc::new(Field::new("sum", DataType::Int64, true))])
+        }
+    }
+
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+        true
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> DFResult<Box<dyn GroupsAccumulator>> {
+        Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode)))
+    }
+
+    fn reverse_expr(&self) -> ReversedUDAF {
+        ReversedUDAF::Identical
+    }
+}
+
+#[derive(Debug)]
+struct SumIntegerAccumulator {
+    sum: Option<i64>,
+    eval_mode: EvalMode,
+    has_all_nulls: bool,
+}
+
+impl SumIntegerAccumulator {
+    fn new(eval_mode: EvalMode) -> Self {
+        if eval_mode == EvalMode::Try {
+            Self {
+                // Try mode starts with a sum of 0: if the sum were initialized to None we could
+                // not tell whether None means "all inputs were null" or "an overflow occurred".
+                sum: Some(0),
+                has_all_nulls: true,
+                eval_mode,
+            }
+        } else {
+            Self {
+                sum: None,
+                has_all_nulls: false,
+                eval_mode,
+            }
+        }
+    }
+}
+
+impl Accumulator for SumIntegerAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
+        // Internal helper that adds the batch values to `sum`. In Try eval mode it returns
+        // Ok(None) instead of an error when the addition overflows; the caller then records
+        // a null sum with `has_all_nulls` set to false.
+        fn update_sum_internal<T>(
+            int_array: &PrimitiveArray<T>,
+            eval_mode: EvalMode,
+            mut sum: i64,
+        ) -> Result<Option<i64>, DataFusionError>
+        where
+            T: ArrowPrimitiveType,
+        {
+            for i in 0..int_array.len() {
+                if !int_array.is_null(i) {
+                    let v = int_array.value(i).to_i64().unwrap();
+                    match eval_mode {
+                        EvalMode::Legacy => {
+                            sum = v.add_wrapping(sum);
+                        }
+                        EvalMode::Ansi | EvalMode::Try => {
+                            match v.add_checked(sum) {
+                                Ok(v) => sum = v,
+                                Err(_e) => {
+                                    return if eval_mode == EvalMode::Ansi {
+                                        Err(DataFusionError::from(arithmetic_overflow_error(
+                                            "integer",
+                                        )))
+                                    } else {
+                                        return Ok(None);
+                                    };
+                                }
+                            };
+                        }
+                    }
+                }
+            }
+            Ok(Some(sum))
+        }
+
+        if self.eval_mode == EvalMode::Try && !self.has_all_nulls && self.sum.is_none() {
+            // We saw an overflow earlier (Try eval mode), so skip processing this batch.
+            return Ok(());
+        }
+        let values = &values[0];
+        if values.len() == values.null_count() {
+            Ok(())
+        } else {
+            // Not all values are null, so the result is a non-null sum, or a null sum if an
+            // overflow occurs in Try eval mode.
+            let running_sum = self.sum.unwrap_or(0);
+            let sum = match values.data_type() {
+                DataType::Int64 => update_sum_internal(
+                    values
+                        .as_any()
+                        .downcast_ref::<PrimitiveArray<Int64Type>>()
+                        .unwrap(),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                DataType::Int32 => update_sum_internal(
+                    values
+                        .as_any()
+                        .downcast_ref::<PrimitiveArray<Int32Type>>()
+                        .unwrap(),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                DataType::Int16 => update_sum_internal(
+                    values
+                        .as_any()
+                        .downcast_ref::<PrimitiveArray<Int16Type>>()
+                        .unwrap(),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                DataType::Int8 => update_sum_internal(
+                    values
+                        .as_any()
+                        .downcast_ref::<PrimitiveArray<Int8Type>>()
+                        .unwrap(),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "unsupported data type: {:?}",
+                        values.data_type()
+                    )));
+                }
+            };
+            self.sum = sum;
+            self.has_all_nulls = false;
+            Ok(())
+        }
+    }
+
+    fn evaluate(&mut self) -> DFResult<ScalarValue> {
+        if self.has_all_nulls {
+            Ok(ScalarValue::Int64(None))
+        } else {
+            Ok(ScalarValue::Int64(self.sum))
+        }
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+
+    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
+        if self.eval_mode == EvalMode::Try {
+            Ok(vec![
+                ScalarValue::Int64(self.sum),
+                ScalarValue::Boolean(Some(self.has_all_nulls)),
+            ])
+        } else {
+            Ok(vec![ScalarValue::Int64(self.sum)])
+        }
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+        let that_sum_array = states[0].as_primitive::<Int64Type>();
+        let that_sum = if that_sum_array.is_null(0) {
+            None
+        } else {
+            Some(that_sum_array.value(0))
+        };
+
+        // Check for overflow first so we can terminate early.
+        if self.eval_mode == EvalMode::Try {
+            let that_has_all_nulls = states[1].as_boolean().value(0);
+            let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+            let this_overflowed = !self.has_all_nulls && self.sum.is_none();
+            if that_overflowed || this_overflowed {
+                self.sum = None;
+                self.has_all_nulls = false;
+                return Ok(());
+            }
+            self.has_all_nulls = self.has_all_nulls && that_has_all_nulls;
+            if that_has_all_nulls {
+                return Ok(());
+            }
+            if self.has_all_nulls {
+                self.sum = that_sum;
+                return Ok(());
+            }
+        } else {
+            if that_sum.is_none() {
+                return Ok(());
+            }
+            if self.sum.is_none() {
+                self.sum = that_sum;
+                return Ok(());
+            }
+        }
+
+        let left = self.sum.unwrap();
+        let right = that_sum.unwrap();
+
+        match self.eval_mode {
+            EvalMode::Legacy => {
+                self.sum = Some(left.add_wrapping(right));
+            }
+            EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) {
+                Ok(v) => self.sum = Some(v),
+                Err(_) => {
+                    if self.eval_mode == EvalMode::Ansi {
+                        return Err(DataFusionError::from(arithmetic_overflow_error("integer")));
+                    } else {
+                        self.sum = None;
+                        self.has_all_nulls = false;
+                    }
+                }
+            },
+        }
+        Ok(())
+    }
+}
+
+struct SumIntGroupsAccumulator {
+    sums: Vec<Option<i64>>,
+    has_all_nulls: Vec<bool>,
+    eval_mode: EvalMode,
+}
+
+impl SumIntGroupsAccumulator {
+    fn new(eval_mode: EvalMode) -> Self {
+        Self {
+            sums: Vec::new(),
+            eval_mode,
+            has_all_nulls: Vec::new(),
+        }
+    }
+
+    fn resize_helper(&mut self, total_num_groups: usize) {
+        if self.eval_mode == EvalMode::Try {
+            self.sums.resize(total_num_groups, Some(0));
+            self.has_all_nulls.resize(total_num_groups, true);
+        } else {
+            self.sums.resize(total_num_groups, None);
+            self.has_all_nulls.resize(total_num_groups, false);
+        }
+    }
+}
+
+impl GroupsAccumulator for SumIntGroupsAccumulator {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        fn update_groups_sum_internal<T>(
+            int_array: &PrimitiveArray<T>,
+            group_indices: &[usize],
+            sums: &mut [Option<i64>],
+            has_all_nulls: &mut [bool],
+            eval_mode: EvalMode,
+        ) -> DFResult<()>
+        where
+            T: ArrowPrimitiveType,
+            T::Native: ArrowNativeType,
+        {
+            for (i, &group_index) in group_indices.iter().enumerate() {
+                if !int_array.is_null(i) {
+                    // A previous value already overflowed this group in Try eval mode, so
+                    // skip processing.
+                    if eval_mode == EvalMode::Try
+                        && !has_all_nulls[group_index]
+                        && sums[group_index].is_none()
+                    {
+                        continue;
+                    }
+                    let v = int_array.value(i).to_i64().ok_or_else(|| {
+                        DataFusionError::Internal("Failed to convert value to i64".to_string())
+                    })?;
+                    match eval_mode {
+                        EvalMode::Legacy => {
+                            sums[group_index] =
+                                Some(sums[group_index].unwrap_or(0).add_wrapping(v));
+                        }
+                        EvalMode::Ansi | EvalMode::Try => {
+                            match sums[group_index].unwrap_or(0).add_checked(v) {
+                                Ok(new_sum) => {
+                                    sums[group_index] = Some(new_sum);
+                                }
+                                Err(_) => {
+                                    if eval_mode == EvalMode::Ansi {
+                                        return Err(DataFusionError::from(
+                                            arithmetic_overflow_error("integer"),
+                                        ));
+                                    } else {
+                                        sums[group_index] = None;
+                                    }
+                                }
+                            };
+                        }
+                    }
+                    has_all_nulls[group_index] = false;
+                }
+            }
+            Ok(())
+        }
+
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+        let values = &values[0];
+        self.resize_helper(total_num_groups);
+
+        match values.data_type() {
+            DataType::Int64 => update_groups_sum_internal(
+                values
+                    .as_any()
+                    .downcast_ref::<PrimitiveArray<Int64Type>>()
+                    .unwrap(),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            DataType::Int32 => update_groups_sum_internal(
+                values
+                    .as_any()
+                    .downcast_ref::<PrimitiveArray<Int32Type>>()
+                    .unwrap(),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            DataType::Int16 => update_groups_sum_internal(
+                values
+                    .as_any()
+                    .downcast_ref::<PrimitiveArray<Int16Type>>()
+                    .unwrap(),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            DataType::Int8 => update_groups_sum_internal(
+                values
+                    .as_any()
+                    .downcast_ref::<PrimitiveArray<Int8Type>>()
+                    .unwrap(),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "Unsupported data type for SumIntGroupsAccumulator: {:?}",
+                    values.data_type()
+                )))
+            }
+        };
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
+        match emit_to {
+            EmitTo::All => {
+                let result = Arc::new(Int64Array::from_iter(
+                    self.sums
+                        .iter()
+                        .zip(self.has_all_nulls.iter())
+                        .map(|(&sum, &is_null)| if is_null { None } else { sum }),
+                )) as ArrayRef;
+
+                self.sums.clear();
+                self.has_all_nulls.clear();
+                Ok(result)
+            }
+            EmitTo::First(n) => {
+                let result = Arc::new(Int64Array::from_iter(
+                    self.sums
+                        .drain(..n)
+                        .zip(self.has_all_nulls.drain(..n))
+                        .map(|(sum, is_null)| if is_null { None } else { sum }),
+                )) as ArrayRef;
+                Ok(result)
+            }
+        }
+    }
+
+    fn state(&mut self, _emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
+        if self.eval_mode == EvalMode::Try {
+            Ok(vec![
+                Arc::new(Int64Array::from(self.sums.clone())),
+                Arc::new(BooleanArray::from(self.has_all_nulls.clone())),
+            ])
+        } else {
+            Ok(vec![Arc::new(Int64Array::from(self.sums.clone()))])
+        }
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+        let that_sums = values[0].as_primitive::<Int64Type>();
+
+        self.resize_helper(total_num_groups);
+
+        let that_sums_is_all_nulls = if self.eval_mode == EvalMode::Try {
+            Some(values[1].as_boolean())
+        } else {
+            None
+        };
+
+        for (idx, &group_index) in group_indices.iter().enumerate() {
+            let that_sum = if that_sums.is_null(idx) {
+                None
+            } else {
+                Some(that_sums.value(idx))
+            };
+
+            if self.eval_mode == EvalMode::Try {
+                let that_has_all_nulls = that_sums_is_all_nulls.unwrap().value(idx);
+
+                let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+                let this_overflowed =
+                    !self.has_all_nulls[group_index] && self.sums[group_index].is_none();
+
+                if that_overflowed || this_overflowed {
+                    self.sums[group_index] = None;
+                    self.has_all_nulls[group_index] = false;
+                    continue;
+                }
+
+                self.has_all_nulls[group_index] =
+                    self.has_all_nulls[group_index] && that_has_all_nulls;
+
+                if that_has_all_nulls {
+                    continue;
+                }
+
+                if self.has_all_nulls[group_index] {
+                    self.sums[group_index] = that_sum;
+                    continue;
+                }
+            } else {
+                if that_sum.is_none() {
+                    continue;
+                }
+                if self.sums[group_index].is_none() {
+                    self.sums[group_index] = that_sum;
+                    continue;
+                }
+            }
+
+            // Both sides are non-null, so update the sum now.
+            let left = self.sums[group_index].unwrap();
+            let right = that_sum.unwrap();
+
+            match self.eval_mode {
+                EvalMode::Legacy => {
+                    self.sums[group_index] = Some(left.add_wrapping(right));
+                }
+                EvalMode::Ansi | EvalMode::Try => {
+                    match left.add_checked(right) {
+                        Ok(v) => self.sums[group_index] = Some(v),
+                        Err(_) => {
+                            if self.eval_mode == EvalMode::Ansi {
+                                return Err(DataFusionError::from(arithmetic_overflow_error(
+                                    "integer",
+                                )));
+                            } else {
+                                // Overflow: record a null sum and clear the all-nulls flag.
+                                self.sums[group_index] = None;
+                                self.has_all_nulls[group_index] = false;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 63e18c145a..b9c7847b09 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -62,6 +62,8 @@ import org.apache.comet.shims.CometExprShim
  */
 object QueryPlanSerde extends Logging with CometExprShim {
 
+  private val integerTypes = Seq(ByteType, ShortType, IntegerType, LongType)
+
   /**
   * Mapping of Spark operator class to Comet operator handler.
   */
@@ -414,7 +416,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
           }
         case s: Sum =>
           if (AggSerde.sumDataTypeSupported(s.dataType) && !s.dataType
-              .isInstanceOf[DecimalType]) {
+              .isInstanceOf[DecimalType] && !integerTypes.contains(s.dataType)) {
             Some(agg)
           } else {
             withInfo(windowExpr, s"datatype ${s.dataType} is not supported", expr)
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 4b8a74c15a..3af5c231b8 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, LongType,
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
+import org.apache.comet.shims.CometEvalModeUtil
 
 object CometMin extends CometAggregateExpressionSerde[Min] {
@@ -201,19 +202,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
       return None
     }
 
-    sum.evalMode match {
-      case EvalMode.ANSI if !CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get() =>
-        withInfo(
-          aggExpr,
-          "ANSI mode is not supported. Set " +
-            s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true to allow it anyway")
-        return None
-      case EvalMode.TRY =>
-        withInfo(aggExpr, "TRY mode is not supported")
-        return None
-      case _ =>
-      // supported
-    }
+    val evalMode = sum.evalMode
 
     val childExpr = exprToProto(sum.child, inputs, binding)
     val dataType = serializeDataType(sum.dataType)
@@ -222,7 +211,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
     val builder = ExprOuterClass.Sum.newBuilder()
     builder.setChild(childExpr.get)
     builder.setDatatype(dataType.get)
-    builder.setFailOnError(sum.evalMode == EvalMode.ANSI)
+    builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode)))
 
     Some(
       ExprOuterClass.AggExpr
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index d502749380..b426e8ad8b 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -3002,6 +3002,227 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for sum - null test") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
+          "null_tbl") {
+          val res = sql("SELECT sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+          assert(res.collect() === Array(Row(null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support for try_sum - null test") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
+          "null_tbl") {
+          val res = sql("SELECT try_sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+          assert(res.collect() === Array(Row(null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support for sum - null test (group by)") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.lang.Long], "a"),
+            (null.asInstanceOf[java.lang.Long], "a"),
+            (null.asInstanceOf[java.lang.Long], "b"),
+            (null.asInstanceOf[java.lang.Long], "b"),
+            (null.asInstanceOf[java.lang.Long], "b")),
+          "tbl") {
+          val res = sql("SELECT _2, sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support for try_sum - null test (group by)") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.lang.Long], "a"),
+            (null.asInstanceOf[java.lang.Long], "a"),
+            (null.asInstanceOf[java.lang.Long], "b"),
+            (null.asInstanceOf[java.lang.Long], "b"),
+            (null.asInstanceOf[java.lang.Long], "b")),
+          "tbl") {
+          val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support - SUM function") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        // Test long overflow
+        withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          if (ansiEnabled) {
+            checkSparkMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ => fail("Exception should be thrown for Long overflow in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+        // Test long underflow
+        withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          if (ansiEnabled) {
+            checkSparkMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ => fail("Exception should be thrown for Long underflow in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+        // Test Int SUM (should not overflow)
+        withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 1)), "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+        // Test Short SUM (should not overflow)
+        withParquetTable(
+          Seq((Short.MaxValue, 1.toShort), (Short.MaxValue, 1.toShort), (100.toShort, 1.toShort)),
+          "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+
+        // Test Byte SUM (should not overflow)
+        withParquetTable(
+          Seq((Byte.MaxValue, 1.toByte), (Byte.MaxValue, 1.toByte), (10.toByte, 1.toByte)),
+          "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+      }
+    }
+  }
+
+  test("ANSI support for SUM - GROUP BY") {
+    // Test Long overflow with GROUP BY to test the GroupsAccumulator with ANSI support
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)),
+          "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
+          if (ansiEnabled) {
+            checkSparkMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ =>
+                fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+
+        withParquetTable(
+          Seq((Long.MinValue, 1), (-100L, 1), (Long.MinValue, 2), (-200L, 2)),
+          "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+          if (ansiEnabled) {
+            checkSparkMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ =>
+                fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+        // Test Int with GROUP BY
+        withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+        // Test Short with GROUP BY
+        withParquetTable(
+          Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
+          "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+
+        // Test Byte with GROUP BY
+        withParquetTable(
+          Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
+          "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+      }
+    }
+  }
+
+  test("try_sum overflow - with GROUP BY") {
+    // Test Long overflow with GROUP BY - some groups overflow while some don't
+    withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      // The first group should return NULL (overflow) and the second group should return 500
+      checkSparkAnswerAndOperator(res)
+    }
+
+    // Test Long underflow with GROUP BY
+    withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      // The first group should return NULL (underflow) and the second group should return -500
+      checkSparkAnswerAndOperator(res)
+    }
+
+    // Test all groups overflow
+    withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (100L, 2)), "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      // Both groups should return NULL
+      checkSparkAnswerAndOperator(res)
+    }
+
+    // Test Short with GROUP BY (should NOT overflow)
+    withParquetTable(
+      Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
+      "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      checkSparkAnswerAndOperator(res)
+    }
+
+    // Test Byte with GROUP BY (no overflow)
+    withParquetTable(
+      Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
+      "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
   test("test integral divide overflow for decimal") {
     if (isSpark40Plus) {
       Seq(true, false)