Commit 2bf2835

feat: Support ANSI mode avg expr (int inputs) (#2817)
1 parent bce61f5 commit 2bf2835

7 files changed: +99 -30 lines changed


docs/source/user-guide/latest/compatibility.md

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ Comet will fall back to Spark for the following expressions when ANSI mode is en
 `spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See
 the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting.
 
-- Average
+- Average (supports all numeric inputs except decimal types)
 - Cast (in some cases)
 
 There is an [epic](https://github.com/apache/datafusion-comet/issues/313) where we are tracking the work to fully implement ANSI support.
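As the doc text above notes, a fallback expression can still be forced to run natively by setting `spark.comet.expression.EXPRNAME.allowIncompatible=true`. A minimal sketch of that opt-in, not part of this commit, assuming a running SparkSession named `spark` and using `Average` as the expression class name from the list above:

// Hypothetical opt-in: let Comet run Average natively even in the cases
// still flagged as incompatible (e.g. decimal inputs under ANSI mode).
spark.conf.set("spark.comet.expression.Average.allowIncompatible", "true")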

native/core/src/execution/planner.rs

Lines changed: 5 additions & 5 deletions
@@ -1840,19 +1840,19 @@ impl PhysicalPlanner {
         let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
         let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
         let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
+
         let builder = match datatype {
             DataType::Decimal128(_, _) => {
                 let func =
                     AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));
                 AggregateExprBuilder::new(Arc::new(func), vec![child])
             }
             _ => {
-                // cast to the result data type of AVG if the result data type is different
-                // from the input type, e.g. AVG(Int32). We should not expect a cast
-                // failure since it should have already been checked at Spark side.
+                // For all other numeric types (Int8/16/32/64, Float32/64):
+                // Cast to Float64 for accumulation
                 let child: Arc<dyn PhysicalExpr> =
-                    Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
-                let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
+                    Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None));
+                let func = AggregateUDF::new_from_impl(Avg::new("avg", DataType::Float64));
                 AggregateExprBuilder::new(Arc::new(func), vec![child])
             }
         };
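With this change, any non-decimal input to `avg` is cast to Float64 and accumulated by the plain `Avg` UDAF, rather than being cast to the declared result type. The Spark-side semantics this is meant to line up with can be sketched as below; the snippet is illustrative only (not part of the commit) and assumes a SparkSession named `spark`:

import org.apache.spark.sql.functions.{avg, col, count, sum}

// avg over an integer column behaves like a double sum divided by the
// non-null count, so integer overflow cannot occur during accumulation.
val df = spark.range(1, 6).toDF("i") // values 1, 2, 3, 4, 5
df.select(avg(col("i"))).show() // 3.0
df.select(sum(col("i").cast("double")) / count(col("i"))).show() // also 3.0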

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ message Avg {
   Expr child = 1;
   DataType datatype = 2;
   DataType sum_datatype = 3;
-  bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
+  EvalMode eval_mode = 4;
 }
 
 message First {
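The unused `fail_on_error` boolean is replaced by the shared `EvalMode` enum, so the native side sees all three of Spark's evaluation modes instead of a single ANSI flag. A conceptual Scala sketch, not part of the commit; `describeEvalMode` is a hypothetical helper, and only the `EvalMode` values come from Spark:

import org.apache.spark.sql.catalyst.expressions.EvalMode

// Hypothetical helper summarizing the three modes the new proto field carries.
def describeEvalMode(mode: EvalMode.Value): String = mode match {
  case EvalMode.LEGACY => "LEGACY: Spark's default, lenient semantics"
  case EvalMode.ANSI   => "ANSI: raise errors (e.g. on overflow) per the SQL standard"
  case EvalMode.TRY    => "TRY: return null where ANSI mode would raise"
}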

native/spark-expr/src/agg_funcs/avg.rs

Lines changed: 7 additions & 7 deletions
@@ -73,7 +73,7 @@ impl AggregateUDFImpl for Avg {
     }
 
     fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-        // instantiate specialized accumulator based for the type
+        // All numeric types use Float64 accumulation after casting
         match (&self.input_data_type, &self.result_data_type) {
             (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
             _ => not_impl_err!(
@@ -115,7 +115,6 @@ impl AggregateUDFImpl for Avg {
         &self,
         _args: AccumulatorArgs,
     ) -> Result<Box<dyn GroupsAccumulator>> {
-        // instantiate specialized accumulator based for the type
         match (&self.input_data_type, &self.result_data_type) {
             (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
                 &self.input_data_type,
@@ -172,7 +171,7 @@ impl Accumulator for AvgAccumulator {
         // counts are summed
         self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
 
-        // sums are summed
+        // sums are summed - no overflow checking in all Eval Modes
        if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
            let v = self.sum.get_or_insert(0.);
            *v += x;
@@ -182,7 +181,7 @@
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
         if self.count == 0 {
-            // If all input are nulls, count will be 0 and we will get null after the division.
+            // If all input are nulls, count will be 0, and we will get null after the division.
             // This is consistent with Spark Average implementation.
             Ok(ScalarValue::Float64(None))
         } else {
@@ -198,7 +197,8 @@
 }
 
 /// An accumulator to compute the average of `[PrimitiveArray<T>]`.
-/// Stores values as native types, and does overflow checking
+/// Stores values as native types (
+/// no overflow check all eval modes since inf is a perfectly valid value per spark impl)
 ///
 /// F: Function that calculates the average value from a sum of
 /// T::Native and a total count
@@ -260,6 +260,7 @@ where
         if values.null_count() == 0 {
             for (&group_index, &value) in iter {
                 let sum = &mut self.sums[group_index];
+                // No overflow checking - Infinity is a valid result
                 *sum = (*sum).add_wrapping(value);
                 self.counts[group_index] += 1;
             }
@@ -296,7 +297,7 @@ where
             self.counts[group_index] += partial_count;
         }
 
-        // update sums
+        // update sums - no overflow checking (in all eval modes)
         self.sums.resize(total_num_groups, T::default_value());
         let iter2 = group_indices.iter().zip(partial_sums.values().iter());
         for (&group_index, &new_value) in iter2 {
@@ -325,7 +326,6 @@ where
         Ok(Arc::new(array))
     }
 
-    // return arrays for sums and counts
     fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
         let counts = emit_to.take_needed(&mut self.counts);
         let counts = Int64Array::new(counts.into(), None);
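The updated comments make the accumulation contract explicit: sums are kept as `f64` and are never overflow-checked in any eval mode, because a sum that leaves the finite range simply becomes Infinity, which the commit notes Spark also treats as a valid value. An illustrative sketch, not part of the commit, assuming a SparkSession named `spark` (the printed value is indicative only):

// Even with ANSI mode on, averaging values near Long.MaxValue returns a
// Double (roughly 9.22e18 here) rather than raising an overflow error,
// because the values are summed as Doubles.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark
  .sql("SELECT avg(v) FROM VALUES (9223372036854775807L), (9223372036854775807L) AS t(v)")
  .show()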

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 2 additions & 13 deletions
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
@@ -151,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] {
 
 object CometAverage extends CometAggregateExpressionSerde[Average] {
 
-  override def getSupportLevel(avg: Average): SupportLevel = {
-    avg.evalMode match {
-      case EvalMode.ANSI =>
-        Incompatible(Some("ANSI mode is not supported"))
-      case EvalMode.TRY =>
-        Incompatible(Some("TRY mode is not supported"))
-      case _ =>
-        Compatible()
-    }
-  }
-
   override def convert(
       aggExpr: AggregateExpression,
       avg: Average,
@@ -193,7 +182,7 @@
     val builder = ExprOuterClass.Avg.newBuilder()
     builder.setChild(childExpr.get)
     builder.setDatatype(dataType.get)
-    builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
+    builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode)))
     builder.setSumDatatype(sumDataType.get)
 
     Some(

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 83 additions & 0 deletions
@@ -1471,6 +1471,89 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("AVG and try_avg - basic functionality") {
+    withParquetTable(
+      Seq(
+        (10L, 1),
+        (20L, 1),
+        (null.asInstanceOf[Long], 1),
+        (100L, 2),
+        (200L, 2),
+        (null.asInstanceOf[Long], 3)),
+      "tbl") {
+
+      Seq(true, false).foreach({ ansiMode =>
+        // without GROUP BY
+        withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+          val res = sql("SELECT avg(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+
+        // with GROUP BY
+        withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+          val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+      })
+
+      // try_avg without GROUP BY
+      val resTry = sql("SELECT try_avg(_1) FROM tbl")
+      checkSparkAnswerAndOperator(resTry)
+
+      // try_avg with GROUP BY
+      val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+      checkSparkAnswerAndOperator(resTryGroup)
+
+    }
+  }
+
+  test("AVG and try_avg - special numbers") {
+
+    val negativeNumbers: Seq[(Long, Int)] = Seq(
+      (-1L, 1),
+      (-123L, 1),
+      (-456L, 1),
+      (Long.MinValue, 1),
+      (Long.MinValue, 1),
+      (Long.MinValue, 2),
+      (Long.MinValue, 2),
+      (null.asInstanceOf[Long], 3))
+
+    val zeroSeq: Seq[(Long, Int)] =
+      Seq((0L, 1), (-0L, 1), (+0L, 2), (+0L, 2), (null.asInstanceOf[Long], 3))
+
+    val highValNumbers: Seq[(Long, Int)] = Seq(
+      (Long.MaxValue, 1),
+      (Long.MaxValue, 1),
+      (Long.MaxValue, 2),
+      (Long.MaxValue, 2),
+      (null.asInstanceOf[Long], 3))
+
+    val inputs = Seq(negativeNumbers, highValNumbers, zeroSeq)
+    inputs.foreach(inputSeq => {
+      withParquetTable(inputSeq, "tbl") {
+        Seq(true, false).foreach({ ansiMode =>
+          // without GROUP BY
+          withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+            checkSparkAnswerAndOperator("SELECT avg(_1) FROM tbl")
+          }
+
+          // with GROUP BY
+          withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+            checkSparkAnswerAndOperator("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+          }
+        })
+
+        // try_avg without GROUP BY
+        checkSparkAnswerAndOperator("SELECT try_avg(_1) FROM tbl")
+
+        // try_avg with GROUP BY
+        checkSparkAnswerAndOperator("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+
+      }
+    })
+  }
+
   test("ANSI support for sum - null test") {
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {

spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala

Lines changed: 0 additions & 3 deletions
@@ -29,7 +29,6 @@ import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
 import org.apache.spark.sql.TPCDSBase
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Average
 import org.apache.spark.sql.catalyst.util.resourceToString
 import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -226,8 +225,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
       CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "false",
       CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
       CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
-      // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
-      CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
       // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
