From 3bc09b175ac005c996eb7e4d6e97bd913394a0e9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 17:56:18 -0600 Subject: [PATCH 01/12] implement custom CountNotNull aggregate --- native/core/src/execution/planner.rs | 48 +--- .../src/agg_funcs/count_not_null.rs | 239 ++++++++++++++++++ native/spark-expr/src/agg_funcs/mod.rs | 2 + .../comet/exec/CometAggregateSuite.scala | 25 ++ 4 files changed, 274 insertions(+), 40 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/count_not_null.rs diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 6051a459e3..0782344d97 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -109,10 +109,10 @@ use datafusion_comet_proto::{ }; use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, - GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike, - RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr, - ToJson, UnboundColumn, Variance, + ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, CountNotNull, Covariance, + CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract, + NormalizeNaNAndZero, RLike, RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, + SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance, }; use itertools::Itertools; use jni::objects::GlobalRef; @@ -1903,42 +1903,10 @@ impl PhysicalPlanner { ) -> Result { match spark_expr.expr_struct.as_ref().unwrap() { AggExprStruct::Count(expr) => { - assert!(!expr.children.is_empty()); - // Using `count_udaf` from Comet is exceptionally slow for some reason, so - // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))` - // 
https://github.com/apache/datafusion-comet/issues/744 - - let children = expr - .children - .iter() - .map(|child| self.create_expr(child, Arc::clone(&schema))) - .collect::, _>>()?; - - // create `IS NOT NULL expr` and join them with `AND` if there are multiple - let not_null_expr: Arc = children.iter().skip(1).fold( - Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) as Arc, - |acc, child| { - Arc::new(BinaryExpr::new( - acc, - DataFusionOperator::And, - Arc::new(IsNotNullExpr::new(Arc::clone(child))), - )) - }, - ); - - let child = Arc::new(IfExpr::new( - not_null_expr, - Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), - Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), - )); - - AggregateExprBuilder::new(sum_udaf(), vec![child]) - .schema(schema) - .alias("count") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + assert_eq!(expr.children.len(), 1); + let child = self.create_expr(&expr.children[0], Arc::clone(&schema))?; + let func = AggregateUDF::new_from_impl(CountNotNull::new()); + Self::create_aggr_func_expr("count", schema, vec![child], func) } AggExprStruct::Min(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; diff --git a/native/spark-expr/src/agg_funcs/count_not_null.rs b/native/spark-expr/src/agg_funcs/count_not_null.rs new file mode 100644 index 0000000000..4aea0b33c0 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/count_not_null.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::array::{Array, Int64Array}; +use arrow::datatypes::DataType; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, TypeSignature, Volatility, +}; + +/// CountNotNull aggregate function +/// Counts the number of non-null values in the input expression +#[derive(Debug)] +pub struct CountNotNull { + signature: Signature, +} + +impl Default for CountNotNull { + fn default() -> Self { + Self::new() + } +} + +impl CountNotNull { + pub fn new() -> Self { + Self { + signature: Signature::one_of(vec![TypeSignature::Any(1)], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for CountNotNull { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "count_not_null" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Int64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(CountNotNullAccumulator::new())) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { + Ok(Box::new(CountNotNullGroupsAccumulator::new())) + } +} + +#[derive(Debug)] +struct CountNotNullAccumulator { + count: i64, +} + +impl 
CountNotNullAccumulator { + fn new() -> Self { + Self { count: 0 } + } +} + +impl Accumulator for CountNotNullAccumulator { + fn update_batch(&mut self, values: &[Arc]) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + let array = &values[0]; + let non_null_count = array.len() - array.null_count(); + self.count += non_null_count as i64; + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + Ok(ScalarValue::Int64(Some(self.count))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + Ok(vec![ScalarValue::Int64(Some(self.count))]) + } + + fn merge_batch(&mut self, states: &[Arc]) -> DFResult<()> { + if states.is_empty() { + return Ok(()); + } + + let counts = states[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Expected Int64Array".to_string()))?; + + for i in 0..counts.len() { + if let Some(count) = counts.value(i).into() { + self.count += count; + } + } + Ok(()) + } +} + +#[derive(Debug)] +struct CountNotNullGroupsAccumulator { + counts: Vec, +} + +impl CountNotNullGroupsAccumulator { + fn new() -> Self { + Self { counts: vec![] } + } +} + +impl GroupsAccumulator for CountNotNullGroupsAccumulator { + fn update_batch( + &mut self, + values: &[Arc], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Resize counts if needed + if self.counts.len() < total_num_groups { + self.counts.resize(total_num_groups, 0); + } + + let array = &values[0]; + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + // Check filter if present + if let Some(filter) = opt_filter { + if !filter.value(row_idx) { + continue; + } + } + + // Check if value is not null + if !array.is_null(row_idx) { + self.counts[group_idx] += 1; + } + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult> { + let counts = emit_to.take_needed(&mut 
self.counts); + let result = Int64Array::from_iter_values(counts.iter().copied()); + Ok(Arc::new(result)) + } + + fn state(&mut self, emit_to: EmitTo) -> DFResult>> { + let counts = emit_to.take_needed(&mut self.counts); + let result = Int64Array::from_iter_values(counts.iter().copied()); + Ok(vec![Arc::new(result)]) + } + + fn merge_batch( + &mut self, + values: &[Arc], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Resize counts if needed + if self.counts.len() < total_num_groups { + self.counts.resize(total_num_groups, 0); + } + + let counts_array = values[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Expected Int64Array".to_string()))?; + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + // Check filter if present + if let Some(filter) = opt_filter { + if !filter.value(row_idx) { + continue; + } + } + + if let Some(count) = counts_array.value(row_idx).into() { + self.counts[group_idx] += count; + } + } + + Ok(()) + } + + fn size(&self) -> usize { + std::mem::size_of::() + self.counts.capacity() * std::mem::size_of::() + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 252da78890..7f51e489bc 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -18,6 +18,7 @@ mod avg; mod avg_decimal; mod correlation; +mod count_not_null; mod covariance; mod stddev; mod sum_decimal; @@ -26,6 +27,7 @@ mod variance; pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; +pub use count_not_null::CountNotNull; pub use covariance::Covariance; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 
6574d9568d..7ce3885b1d 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -90,6 +90,31 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("count with null values") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { + Seq(true, false).foreach { dictionaryEnabled => + withParquetTable( + Seq((1, null), (2, 42), (null, 43), (4, null), (5, 44)), + "count_null_test", + dictionaryEnabled) { + // Test COUNT on different columns (COUNT should ignore nulls) + checkSparkAnswerAndOperator(sql("SELECT COUNT(_1) FROM count_null_test")) + checkSparkAnswerAndOperator(sql("SELECT COUNT(_2) FROM count_null_test")) + + // Test with GROUP BY + checkSparkAnswerAndOperator( + sql("SELECT _1, COUNT(_2) FROM count_null_test GROUP BY _1 ORDER BY _1")) + + // Test combined with other aggregates + checkSparkAnswerAndOperator(sql("SELECT COUNT(_1), COUNT(_2) FROM count_null_test")) + } + } + } + } + test("lead/lag should return the default value if the offset row does not exist") { withSQLConf( CometConf.COMET_ENABLED.key -> "true", From deb44c421a3b79834675394a52e8d33d1d2d0e3e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 18:29:16 -0600 Subject: [PATCH 02/12] count rows --- native/core/src/execution/planner.rs | 18 +- native/spark-expr/src/agg_funcs/count_rows.rs | 239 ++++++++++++++++++ native/spark-expr/src/agg_funcs/mod.rs | 2 + .../org/apache/comet/serde/aggregates.scala | 4 + .../comet/exec/CometAggregateSuite.scala | 3 + 5 files changed, 263 insertions(+), 3 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/count_rows.rs diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 0782344d97..104a09ba9a 100644 --- a/native/core/src/execution/planner.rs +++ 
b/native/core/src/execution/planner.rs @@ -109,8 +109,8 @@ use datafusion_comet_proto::{ }; use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, CountNotNull, Covariance, - CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract, + ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, CountNotNull, CountRows, + Covariance, CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike, RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance, }; @@ -1905,7 +1905,19 @@ impl PhysicalPlanner { AggExprStruct::Count(expr) => { assert_eq!(expr.children.len(), 1); let child = self.create_expr(&expr.children[0], Arc::clone(&schema))?; - let func = AggregateUDF::new_from_impl(CountNotNull::new()); + + // Check if the child is a literal (for COUNT(*) which is COUNT(1)) + let is_literal = + matches!(child.as_any().downcast_ref::(), Some(_)); + + let func = if is_literal { + // COUNT(*) - count all rows including nulls + AggregateUDF::new_from_impl(CountRows::new()) + } else { + // COUNT(column) - count only non-null values + AggregateUDF::new_from_impl(CountNotNull::new()) + }; + Self::create_aggr_func_expr("count", schema, vec![child], func) } AggExprStruct::Min(expr) => { diff --git a/native/spark-expr/src/agg_funcs/count_rows.rs b/native/spark-expr/src/agg_funcs/count_rows.rs new file mode 100644 index 0000000000..b0d78ef199 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/count_rows.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::array::{Array, Int64Array}; +use arrow::datatypes::DataType; +use datafusion::common::{Result as DFResult, ScalarValue}; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, TypeSignature, Volatility, +}; + +/// CountRows aggregate function +/// Counts all rows including those with null values (equivalent to COUNT(*)) +#[derive(Debug)] +pub struct CountRows { + signature: Signature, +} + +impl Default for CountRows { + fn default() -> Self { + Self::new() + } +} + +impl CountRows { + pub fn new() -> Self { + Self { + signature: Signature::one_of(vec![TypeSignature::Any(1)], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for CountRows { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "count_rows" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Int64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(CountRowsAccumulator::new())) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { + 
Ok(Box::new(CountRowsGroupsAccumulator::new())) + } +} + +#[derive(Debug)] +struct CountRowsAccumulator { + count: i64, +} + +impl CountRowsAccumulator { + fn new() -> Self { + Self { count: 0 } + } +} + +impl Accumulator for CountRowsAccumulator { + fn update_batch(&mut self, values: &[Arc]) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Count all rows regardless of null values + let array = &values[0]; + self.count += array.len() as i64; + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + Ok(ScalarValue::Int64(Some(self.count))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + Ok(vec![ScalarValue::Int64(Some(self.count))]) + } + + fn merge_batch(&mut self, states: &[Arc]) -> DFResult<()> { + if states.is_empty() { + return Ok(()); + } + + let counts = states[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::common::DataFusionError::Internal("Expected Int64Array".to_string()) + })?; + + for i in 0..counts.len() { + if let Some(count) = counts.value(i).into() { + self.count += count; + } + } + Ok(()) + } +} + +#[derive(Debug)] +struct CountRowsGroupsAccumulator { + counts: Vec, +} + +impl CountRowsGroupsAccumulator { + fn new() -> Self { + Self { counts: vec![] } + } +} + +impl GroupsAccumulator for CountRowsGroupsAccumulator { + fn update_batch( + &mut self, + values: &[Arc], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Resize counts if needed + if self.counts.len() < total_num_groups { + self.counts.resize(total_num_groups, 0); + } + + // Count all rows for each group, regardless of null values + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + // Check filter if present; it is indexed by input row, not by group + if let Some(filter) = opt_filter { + if filter.value(row_idx) { + self.counts[group_idx] += 1; + } + } else { + self.counts[group_idx] += 1; + } + } + + Ok(())
+ } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult> { + let counts = emit_to.take_needed(&mut self.counts); + let result = Int64Array::from_iter_values(counts.iter().copied()); + Ok(Arc::new(result)) + } + + fn state(&mut self, emit_to: EmitTo) -> DFResult>> { + let counts = emit_to.take_needed(&mut self.counts); + let result = Int64Array::from_iter_values(counts.iter().copied()); + Ok(vec![Arc::new(result)]) + } + + fn merge_batch( + &mut self, + values: &[Arc], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Resize counts if needed + if self.counts.len() < total_num_groups { + self.counts.resize(total_num_groups, 0); + } + + let counts_array = values[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::common::DataFusionError::Internal("Expected Int64Array".to_string()) + })?; + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + // Check filter if present + if let Some(filter) = opt_filter { + if !filter.value(row_idx) { + continue; + } + } + + if let Some(count) = counts_array.value(row_idx).into() { + self.counts[group_idx] += count; + } + } + + Ok(()) + } + + fn size(&self) -> usize { + std::mem::size_of::() + self.counts.capacity() * std::mem::size_of::() + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 7f51e489bc..5cb7dbcbe9 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -19,6 +19,7 @@ mod avg; mod avg_decimal; mod correlation; mod count_not_null; +mod count_rows; mod covariance; mod stddev; mod sum_decimal; @@ -28,6 +29,7 @@ pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; pub use count_not_null::CountNotNull; +pub use count_rows::CountRows; pub use covariance::Covariance; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; diff --git 
a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 51c8951289..bdaf21a46e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -110,6 +110,10 @@ object CometCount extends CometAggregateExpressionSerde[Count] { binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { val exprChildren = expr.children.map(exprToProto(_, inputs, binding)) + + // scalastyle:off + println(exprChildren.head.getClass) + if (exprChildren.forall(_.isDefined)) { val builder = ExprOuterClass.Count.newBuilder() builder.addAllChildren(exprChildren.map(_.get).asJava) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 7ce3885b1d..6069334768 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -108,6 +108,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswerAndOperator( sql("SELECT _1, COUNT(_2) FROM count_null_test GROUP BY _1 ORDER BY _1")) + checkSparkAnswerAndOperator( + sql("SELECT _1, COUNT(*) FROM count_null_test GROUP BY _1 ORDER BY _1")) + // Test combined with other aggregates checkSparkAnswerAndOperator(sql("SELECT COUNT(_1), COUNT(_2) FROM count_null_test")) } From 39379dede7c66c4d4042005a675abc21d465e975 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 18:38:00 -0600 Subject: [PATCH 03/12] remove bad test --- .../org/apache/comet/serde/aggregates.scala | 4 --- .../comet/exec/CometAggregateSuite.scala | 28 ------------------- 2 files changed, 32 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index bdaf21a46e..51c8951289 100644 --- 
a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -110,10 +110,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] { binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { val exprChildren = expr.children.map(exprToProto(_, inputs, binding)) - - // scalastyle:off - println(exprChildren.head.getClass) - if (exprChildren.forall(_.isDefined)) { val builder = ExprOuterClass.Count.newBuilder() builder.addAllChildren(exprChildren.map(_.get).asJava) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 6069334768..6574d9568d 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -90,34 +90,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("count with null values") { - withSQLConf( - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { - Seq(true, false).foreach { dictionaryEnabled => - withParquetTable( - Seq((1, null), (2, 42), (null, 43), (4, null), (5, 44)), - "count_null_test", - dictionaryEnabled) { - // Test COUNT on different columns (COUNT should ignore nulls) - checkSparkAnswerAndOperator(sql("SELECT COUNT(_1) FROM count_null_test")) - checkSparkAnswerAndOperator(sql("SELECT COUNT(_2) FROM count_null_test")) - - // Test with GROUP BY - checkSparkAnswerAndOperator( - sql("SELECT _1, COUNT(_2) FROM count_null_test GROUP BY _1 ORDER BY _1")) - - checkSparkAnswerAndOperator( - sql("SELECT _1, COUNT(*) FROM count_null_test GROUP BY _1 ORDER BY _1")) - - // Test combined with other aggregates - checkSparkAnswerAndOperator(sql("SELECT COUNT(_1), COUNT(_2) FROM count_null_test")) - } - } - } - } - test("lead/lag should 
return the default value if the offset row does not exist") { withSQLConf( CometConf.COMET_ENABLED.key -> "true", From 1822113edfbab402c1cbb5edb346fce4e2dc5bb5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 18:55:27 -0600 Subject: [PATCH 04/12] tests --- .../org/apache/comet/CometFuzzTestSuite.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala index ed250e141c..89f4c63b4b 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala @@ -192,7 +192,7 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("aggregate group by single column") { + test("count(*) group by single column") { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") for (col <- df.columns) { @@ -205,6 +205,20 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("count(col) group by single column") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + val groupCol = df.columns.head + for (col <- df.columns.drop(1)) { + // cannot run fully natively due to range partitioning and sort + val sql = s"SELECT $groupCol, count($col) FROM t1 GROUP BY $groupCol ORDER BY $groupCol" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + } + } + test("min/max aggregate") { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") From c2daea82725522f541fd969d43999e62cf20995e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 18:56:46 -0600 Subject: [PATCH 05/12] format --- native/core/src/execution/planner.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/planner.rs 
b/native/core/src/execution/planner.rs index 104a09ba9a..28fdc0ccf2 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1907,8 +1907,9 @@ impl PhysicalPlanner { let child = self.create_expr(&expr.children[0], Arc::clone(&schema))?; // Check if the child is a literal (for COUNT(*) which is COUNT(1)) - let is_literal = - matches!(child.as_any().downcast_ref::(), Some(_)); + let is_literal = child.as_any().downcast_ref::().is_some(); + + println!("is_literal = {is_literal}"); let func = if is_literal { // COUNT(*) - count all rows including nulls From 9c6e4bbd7101e61c3f66668d6dfe6dac80e78270 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 19:01:04 -0600 Subject: [PATCH 06/12] remove println --- native/core/src/execution/planner.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 28fdc0ccf2..85c4f26d06 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1909,8 +1909,6 @@ impl PhysicalPlanner { // Check if the child is a literal (for COUNT(*) which is COUNT(1)) let is_literal = child.as_any().downcast_ref::().is_some(); - println!("is_literal = {is_literal}"); - let func = if is_literal { // COUNT(*) - count all rows including nulls AggregateUDF::new_from_impl(CountRows::new()) From 64d8c096ce5b5d394fb2c881feb1dd2927d05074 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 20:19:25 -0600 Subject: [PATCH 07/12] fix regression --- spark/src/main/scala/org/apache/comet/serde/aggregates.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 51c8951289..293119c919 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -109,6 +109,10 @@ object CometCount 
extends CometAggregateExpressionSerde[Count] { inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + if (expr.children.length > 1) { + withInfo(expr, "Count only supported with a single argument") + return None + } val exprChildren = expr.children.map(exprToProto(_, inputs, binding)) if (exprChildren.forall(_.isDefined)) { val builder = ExprOuterClass.Count.newBuilder() From 1ab320330955e115a499e6a9afa034ba74b9d423 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 21:00:49 -0600 Subject: [PATCH 08/12] reinstate multi input count and add tests --- .../org/apache/comet/serde/aggregates.scala | 4 - .../comet/CometFuzzAggregateSuite.scala | 71 ++++++++++ .../org/apache/comet/CometFuzzTestBase.scala | 88 +++++++++++++ .../org/apache/comet/CometFuzzTestSuite.scala | 122 +----------------- 4 files changed, 160 insertions(+), 125 deletions(-) create mode 100644 spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala create mode 100644 spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 293119c919..51c8951289 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -109,10 +109,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] { inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - if (expr.children.length > 1) { - withInfo(expr, "Count only supported with a single argument") - return None - } val exprChildren = expr.children.map(exprToProto(_, inputs, binding)) if (exprChildren.forall(_.isDefined)) { val builder = ExprOuterClass.Count.newBuilder() diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala new file mode 
100644 index 0000000000..2f1d3a9bb9 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala @@ -0,0 +1,71 @@ +package org.apache.comet + +class CometFuzzAggregateSuite extends CometFuzzTestBase { + + test("count distinct") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + for (col <- df.columns) { + val sql = s"SELECT count(distinct $col) FROM t1" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + } + } + + test("count(*) group by single column") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + for (col <- df.columns) { + // cannot run fully natively due to range partitioning and sort + val sql = s"SELECT $col, count(*) FROM t1 GROUP BY $col ORDER BY $col" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + } + } + + test("count(col) group by single column") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + val groupCol = df.columns.head + for (col <- df.columns.drop(1)) { + // cannot run fully natively due to range partitioning and sort + val sql = s"SELECT $groupCol, count($col) FROM t1 GROUP BY $groupCol ORDER BY $groupCol" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + } + } + + test("count(col1, col2, ..) 
group by single column") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + val groupCol = df.columns.head + val otherCol = df.columns.drop(1) + // cannot run fully natively due to range partitioning and sort + val sql = s"SELECT $groupCol, count(${otherCol.mkString(", ")}) FROM t1 " + + s"GROUP BY $groupCol ORDER BY $groupCol" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + } + + test("min/max aggregate") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + for (col <- df.columns) { + // cannot run fully native due to HashAggregate + val sql = s"SELECT min($col), max($col) FROM t1" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + } + } + +} diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala new file mode 100644 index 0000000000..af37a84b1e --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala @@ -0,0 +1,88 @@ +package org.apache.comet + +import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.internal.SQLConf +import org.scalactic.source.Position +import org.scalatest.Tag + +import java.io.File +import java.text.SimpleDateFormat +import scala.util.Random + +class CometFuzzTestBase extends CometTestBase with AdaptiveSparkPlanHelper { + + var filename: String = null + + /** + * We use Asia/Kathmandu because it has a non-zero number of minutes as the 
offset, so is an + * interesting edge case. Also, this timezone tends to be different from the default system + * timezone. + * + * Represents UTC+5:45 + */ + val defaultTimezone = "Asia/Kathmandu" + + override def beforeAll(): Unit = { + super.beforeAll() + val tempDir = System.getProperty("java.io.tmpdir") + filename = s"$tempDir/CometFuzzTestSuite_${System.currentTimeMillis()}.parquet" + val random = new Random(42) + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) { + val options = + DataGenOptions( + generateArray = true, + generateStruct = true, + generateNegativeZero = false, + // override base date due to known issues with experimental scans + baseDate = + new SimpleDateFormat("YYYY-MM-DD hh:mm:ss").parse("2024-05-25 12:34:56").getTime) + ParquetGenerator.makeParquetFile(random, spark, filename, 1000, options) + } + } + + protected override def afterAll(): Unit = { + super.afterAll() + FileUtils.deleteDirectory(new File(filename)) + } + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + Seq("native", "jvm").foreach { shuffleMode => + Seq( + CometConf.SCAN_NATIVE_COMET, + CometConf.SCAN_NATIVE_DATAFUSION, + CometConf.SCAN_NATIVE_ICEBERG_COMPAT).foreach { scanImpl => + super.test(testName + s" ($scanImpl, $shuffleMode shuffle)", testTags: _*) { + withSQLConf( + CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanImpl, + CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> shuffleMode) { + testFun + } + } + } + } + } + + def collectNativeScans(plan: SparkPlan): Seq[SparkPlan] = { + collect(plan) { + case scan: CometScanExec => scan + case scan: CometNativeScanExec => scan + } + } + + def collectCometShuffleExchanges(plan: SparkPlan): Seq[SparkPlan] = { + collect(plan) { case exchange: CometShuffleExchangeExec => + exchange + } + } + +} diff --git 
a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala index 89f4c63b4b..f4d3a2259d 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala @@ -41,43 +41,7 @@ import org.apache.spark.sql.types._ import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} -class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { - - private var filename: String = null - - /** - * We use Asia/Kathmandu because it has a non-zero number of minutes as the offset, so is an - * interesting edge case. Also, this timezone tends to be different from the default system - * timezone. - * - * Represents UTC+5:45 - */ - private val defaultTimezone = "Asia/Kathmandu" - - override def beforeAll(): Unit = { - super.beforeAll() - val tempDir = System.getProperty("java.io.tmpdir") - filename = s"$tempDir/CometFuzzTestSuite_${System.currentTimeMillis()}.parquet" - val random = new Random(42) - withSQLConf( - CometConf.COMET_ENABLED.key -> "false", - SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) { - val options = - DataGenOptions( - generateArray = true, - generateStruct = true, - generateNegativeZero = false, - // override base date due to known issues with experimental scans - baseDate = - new SimpleDateFormat("YYYY-MM-DD hh:mm:ss").parse("2024-05-25 12:34:56").getTime) - ParquetGenerator.makeParquetFile(random, spark, filename, 1000, options) - } - } - - protected override def afterAll(): Unit = { - super.afterAll() - FileUtils.deleteDirectory(new File(filename)) - } +class CometFuzzTestSuite extends CometFuzzTestBase { test("select *") { val df = spark.read.parquet(filename) @@ -168,18 +132,6 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("count distinct") { - val df = spark.read.parquet(filename) - 
df.createOrReplaceTempView("t1") - for (col <- df.columns) { - val sql = s"SELECT count(distinct $col) FROM t1" - val (_, cometPlan) = checkSparkAnswer(sql) - if (usingDataSourceExec) { - assert(1 == collectNativeScans(cometPlan).length) - } - } - } - test("order by multiple columns") { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") @@ -192,46 +144,6 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("count(*) group by single column") { - val df = spark.read.parquet(filename) - df.createOrReplaceTempView("t1") - for (col <- df.columns) { - // cannot run fully natively due to range partitioning and sort - val sql = s"SELECT $col, count(*) FROM t1 GROUP BY $col ORDER BY $col" - val (_, cometPlan) = checkSparkAnswer(sql) - if (usingDataSourceExec) { - assert(1 == collectNativeScans(cometPlan).length) - } - } - } - - test("count(col) group by single column") { - val df = spark.read.parquet(filename) - df.createOrReplaceTempView("t1") - val groupCol = df.columns.head - for (col <- df.columns.drop(1)) { - // cannot run fully natively due to range partitioning and sort - val sql = s"SELECT $groupCol, count($col) FROM t1 GROUP BY $groupCol ORDER BY $groupCol" - val (_, cometPlan) = checkSparkAnswer(sql) - if (usingDataSourceExec) { - assert(1 == collectNativeScans(cometPlan).length) - } - } - } - - test("min/max aggregate") { - val df = spark.read.parquet(filename) - df.createOrReplaceTempView("t1") - for (col <- df.columns) { - // cannot run fully native due to HashAggregate - val sql = s"SELECT min($col), max($col) FROM t1" - val (_, cometPlan) = checkSparkAnswer(sql) - if (usingDataSourceExec) { - assert(1 == collectNativeScans(cometPlan).length) - } - } - } - test("distribute by single column (complex types)") { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") @@ -385,36 +297,4 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - override protected 
def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - Seq("native", "jvm").foreach { shuffleMode => - Seq( - CometConf.SCAN_NATIVE_COMET, - CometConf.SCAN_NATIVE_DATAFUSION, - CometConf.SCAN_NATIVE_ICEBERG_COMPAT).foreach { scanImpl => - super.test(testName + s" ($scanImpl, $shuffleMode shuffle)", testTags: _*) { - withSQLConf( - CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanImpl, - CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "true", - CometConf.COMET_SHUFFLE_MODE.key -> shuffleMode) { - testFun - } - } - } - } - } - - private def collectNativeScans(plan: SparkPlan): Seq[SparkPlan] = { - collect(plan) { - case scan: CometScanExec => scan - case scan: CometNativeScanExec => scan - } - } - - private def collectCometShuffleExchanges(plan: SparkPlan): Seq[SparkPlan] = { - collect(plan) { case exchange: CometShuffleExchangeExec => - exchange - } - } - } From 8aaa0d8c4004061eb7c8d6b1f40dcba88ee8f604 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 21:02:19 -0600 Subject: [PATCH 09/12] fix --- native/core/src/execution/planner.rs | 61 ++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 85c4f26d06..0c88d58741 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1903,21 +1903,58 @@ impl PhysicalPlanner { ) -> Result<AggregateFunctionExpr, ExecutionError> { match spark_expr.expr_struct.as_ref().unwrap() { AggExprStruct::Count(expr) => { - assert_eq!(expr.children.len(), 1); - let child = self.create_expr(&expr.children[0], Arc::clone(&schema))?; + assert!(!expr.children.is_empty()); + if expr.children.len() == 1 { + // fast path for single expression case + let child = self.create_expr(&expr.children[0], Arc::clone(&schema))?; + // Check if the child is a literal for `COUNT(1)` case + let is_literal = child.as_any().downcast_ref::<Literal>().is_some(); + let func = if is_literal { + // COUNT(1) -
count all rows including nulls + AggregateUDF::new_from_impl(CountRows::new()) + } else { + // COUNT(expr) - count only non-null values + AggregateUDF::new_from_impl(CountNotNull::new()) + }; + Self::create_aggr_func_expr("count", schema, vec![child], func) + } else { + // Using `count_udaf` from Comet is exceptionally slow for some reason, so + // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))` + // https://github.com/apache/datafusion-comet/issues/744 - // Check if the child is a literal (for COUNT(*) which is COUNT(1)) - let is_literal = child.as_any().downcast_ref::<Literal>().is_some(); + let children = expr + .children + .iter() + .map(|child| self.create_expr(child, Arc::clone(&schema))) + .collect::<Result<Vec<_>, _>>()?; + + // create `IS NOT NULL expr` and join them with `AND` if there are multiple + let not_null_expr: Arc<dyn PhysicalExpr> = children.iter().skip(1).fold( + Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) + as Arc<dyn PhysicalExpr>, + |acc, child| { + Arc::new(BinaryExpr::new( + acc, + DataFusionOperator::And, + Arc::new(IsNotNullExpr::new(Arc::clone(child))), + )) + }, + ); - let func = if is_literal { - // COUNT(*) - count all rows including nulls - AggregateUDF::new_from_impl(CountRows::new()) - } else { - // COUNT(column) - count only non-null values - AggregateUDF::new_from_impl(CountNotNull::new()) - }; + let child = Arc::new(IfExpr::new( + not_null_expr, + Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); - Self::create_aggr_func_expr("count", schema, vec![child], func) + AggregateExprBuilder::new(sum_udaf(), vec![child]) + .schema(schema) + .alias("count") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } } AggExprStruct::Min(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; From ac53707f0e5e3119ccc443811cd3911526d37a09 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep
2025 21:03:25 -0600 Subject: [PATCH 10/12] is_a --- native/core/src/execution/planner.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 0c88d58741..8dd08bc070 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1908,8 +1908,7 @@ impl PhysicalPlanner { // fast path for single expression case let child = self.create_expr(&expr.children[0], Arc::clone(&schema))?; // Check if the child is a literal for `COUNT(1)` case - let is_literal = child.as_any().downcast_ref::<Literal>().is_some(); - let func = if is_literal { + let func = if child.as_any().is::<Literal>() { // COUNT(1) - count all rows including nulls AggregateUDF::new_from_impl(CountRows::new()) } else { From 3ad0b7e8c9119047dbbfdb5ae28bffe86927a466 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 21:05:43 -0600 Subject: [PATCH 11/12] format --- .../comet/CometFuzzAggregateSuite.scala | 19 ++++++++++ .../org/apache/comet/CometFuzzTestBase.scala | 38 +++++++++++++++---- .../org/apache/comet/CometFuzzTestSuite.scala | 13 +------ 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala index 2f1d3a9bb9..6c625ae053 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.comet class CometFuzzAggregateSuite extends CometFuzzTestBase { diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala index af37a84b1e..a69080e446 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala @@ -1,19 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + package org.apache.comet -import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} +import java.io.File +import java.text.SimpleDateFormat + +import scala.util.Random + +import org.scalactic.source.Position +import org.scalatest.Tag + import org.apache.commons.io.FileUtils import org.apache.spark.sql.CometTestBase -import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal.SQLConf -import org.scalactic.source.Position -import org.scalatest.Tag -import java.io.File -import java.text.SimpleDateFormat -import scala.util.Random +import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} class CometFuzzTestBase extends CometTestBase with AdaptiveSparkPlanHelper { @@ -54,7 +76,7 @@ class CometFuzzTestBase extends CometTestBase with AdaptiveSparkPlanHelper { } override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { + pos: Position): Unit = { Seq("native", "jvm").foreach { shuffleMode => Seq( CometConf.SCAN_NATIVE_COMET, diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala index f4d3a2259d..1b65dd8d44 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala @@ -19,21 +19,10 @@ package org.apache.comet -import java.io.File -import java.text.SimpleDateFormat - import scala.util.Random -import org.scalactic.source.Position -import org.scalatest.Tag - import org.apache.commons.codec.binary.Hex -import org.apache.commons.io.FileUtils -import org.apache.spark.sql.CometTestBase -import 
org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} -import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.types._ From 59dcace877be0248f0e64d3e273a2b5c4b31e345 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Sep 2025 21:12:29 -0600 Subject: [PATCH 12/12] add new test to CI workflow --- .github/workflows/pr_build_linux.yml | 1 + .github/workflows/pr_build_macos.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index c45355978e..c0cbf8bbef 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -102,6 +102,7 @@ jobs: - name: "fuzz" value: | org.apache.comet.CometFuzzTestSuite + org.apache.comet.CometFuzzAggregateSuite org.apache.comet.DataGeneratorSuite - name: "shuffle" value: | diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 6c71006e5f..ea09de06f5 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -67,6 +67,7 @@ jobs: - name: "fuzz" value: | org.apache.comet.CometFuzzTestSuite + org.apache.comet.CometFuzzAggregateSuite org.apache.comet.DataGeneratorSuite - name: "shuffle" value: |