diff --git a/.gitignore b/.gitignore
index 4157bf6f28..94877ced70 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,4 @@ dev/dist
 apache-rat-*.jar
 venv
 dev/release/comet-rm/workdir
+spark/benchmarks
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 6051a459e3..0e832599d9 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -30,6 +30,7 @@ use crate::{
 use arrow::compute::CastOptions;
 use arrow::datatypes::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
 use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
+use datafusion::functions_aggregate::count::count_udaf;
 use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_aggregate::min_max::min_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
@@ -1904,35 +1905,13 @@ impl PhysicalPlanner {
         match spark_expr.expr_struct.as_ref().unwrap() {
             AggExprStruct::Count(expr) => {
                 assert!(!expr.children.is_empty());
-                // Using `count_udaf` from Comet is exceptionally slow for some reason, so
-                // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))`
-                // https://github.com/apache/datafusion-comet/issues/744
-
                 let children = expr
                     .children
                     .iter()
                     .map(|child| self.create_expr(child, Arc::clone(&schema)))
                     .collect::<Result<Vec<_>, _>>()?;
 
-                // create `IS NOT NULL expr` and join them with `AND` if there are multiple
-                let not_null_expr: Arc<dyn PhysicalExpr> = children.iter().skip(1).fold(
-                    Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) as Arc<dyn PhysicalExpr>,
-                    |acc, child| {
-                        Arc::new(BinaryExpr::new(
-                            acc,
-                            DataFusionOperator::And,
-                            Arc::new(IsNotNullExpr::new(Arc::clone(child))),
-                        ))
-                    },
-                );
-
-                let child = Arc::new(IfExpr::new(
-                    not_null_expr,
-                    Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
-                    Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
-                ));
-
-                AggregateExprBuilder::new(sum_udaf(), vec![child])
+                AggregateExprBuilder::new(count_udaf(), children)
                     .schema(schema)
                     .alias("count")
                     .with_ignore_nulls(false)
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
index 86b59050ec..47fbe354f5 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
@@ -64,13 +64,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         spark.sql(query).noop()
       }
 
-      benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-        withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
-          spark.sql(query).noop()
-        }
-      }
-
-      benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -111,13 +105,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         spark.sql(query).noop()
       }
 
-      benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-        withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
-          spark.sql(query).noop()
-        }
-      }
-
-      benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -153,15 +141,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         spark.sql(query).noop()
       }
 
-      benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-        withSQLConf(
-          CometConf.COMET_ENABLED.key -> "true",
-          CometConf.COMET_MEMORY_OVERHEAD.key -> "1G") {
-          spark.sql(query).noop()
-        }
-      }
-
-      benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -198,13 +178,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         spark.sql(query).noop()
      }
 
-      benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-        withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
-          spark.sql(query).noop()
-        }
-      }
-
-      benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
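Aside, not part of the patch: the removed planner code rewrote Spark's multi-argument `COUNT(a, b, ...)` as `SUM(IF(a IS NOT NULL AND b IS NOT NULL ..., 1, 0))` to sidestep the slow `count_udaf` path tracked in https://github.com/apache/datafusion-comet/issues/744. The sketch below models that equivalence on plain Rust values so the semantics of the deleted workaround are concrete; the row representation and function names are hypothetical illustrations, not Comet or DataFusion APIs.

```rust
// Illustrative only: models the semantics of the removed workaround, which
// rewrote `COUNT(a, b)` as `SUM(IF(a IS NOT NULL AND b IS NOT NULL, 1, 0))`.
// Rows are modeled as tuples of nullable columns; names here are hypothetical.

/// Spark's COUNT(a, b): count rows where every argument is non-null.
fn count_multi(rows: &[(Option<i64>, Option<i64>)]) -> i64 {
    rows.iter()
        .filter(|(a, b)| a.is_some() && b.is_some())
        .count() as i64
}

/// The workaround: SUM over IF(a IS NOT NULL AND b IS NOT NULL, 1, 0).
fn sum_if_workaround(rows: &[(Option<i64>, Option<i64>)]) -> i64 {
    rows.iter()
        .map(|(a, b)| if a.is_some() && b.is_some() { 1 } else { 0 })
        .sum()
}

fn main() {
    let rows = [(Some(1), Some(2)), (None, Some(3)), (Some(4), None), (None, None)];
    // Both formulations count only the fully non-null row, yielding 1.
    assert_eq!(count_multi(&rows), sum_if_workaround(&rows));
    println!("COUNT = {}", count_multi(&rows));
}
```

Because the two formulations agree row-for-row, the patch can swap the `sum_udaf`-based rewrite for a direct `count_udaf` call without changing results; only the execution path (and its performance) differs.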