diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 6051a459e3..8dd08bc070 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -109,10 +109,10 @@ use datafusion_comet_proto::{ }; use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, - GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike, - RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr, - ToJson, UnboundColumn, Variance, + ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, CountNotNull, CountRows, + Covariance, CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract, + NormalizeNaNAndZero, RLike, RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, + SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance, }; use itertools::Itertools; use jni::objects::GlobalRef; @@ -1904,41 +1904,56 @@ impl PhysicalPlanner { match spark_expr.expr_struct.as_ref().unwrap() { AggExprStruct::Count(expr) => { assert!(!expr.children.is_empty()); - // Using `count_udaf` from Comet is exceptionally slow for some reason, so - // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))` - // https://github.com/apache/datafusion-comet/issues/744 - - let children = expr - .children - .iter() - .map(|child| self.create_expr(child, Arc::clone(&schema))) - .collect::, _>>()?; + if expr.children.len() == 1 { + // fast path for single expression case + let child = self.create_expr(&expr.children[0], Arc::clone(&schema))?; + // Check if the child is a literal for `COUNT(1)` case + let func = if child.as_any().is::() { + // COUNT(1) - count all rows including nulls + AggregateUDF::new_from_impl(CountRows::new()) + } else { + // COUNT(expr) - count only non-null values + AggregateUDF::new_from_impl(CountNotNull::new()) + }; + Self::create_aggr_func_expr("count", schema, vec![child], func) + } else { + // Using `count_udaf` from Comet is exceptionally slow for some reason, so + // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))` + // https://github.com/apache/datafusion-comet/issues/744 - // create `IS NOT NULL expr` and join them with `AND` if there are multiple - let not_null_expr: Arc = children.iter().skip(1).fold( - Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) as Arc, - |acc, child| { - Arc::new(BinaryExpr::new( - acc, - DataFusionOperator::And, - Arc::new(IsNotNullExpr::new(Arc::clone(child))), - )) - }, - ); + let children = expr + .children + .iter() + .map(|child| self.create_expr(child, Arc::clone(&schema))) + .collect::, _>>()?; + + // create `IS NOT NULL expr` and join them with `AND` if there are multiple + let not_null_expr: Arc = children.iter().skip(1).fold( + Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) + as Arc, + |acc, child| { + Arc::new(BinaryExpr::new( + acc, + DataFusionOperator::And, + Arc::new(IsNotNullExpr::new(Arc::clone(child))), + )) + }, + ); - let child = Arc::new(IfExpr::new( - not_null_expr, - Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), - Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), - )); + let child = Arc::new(IfExpr::new( + not_null_expr, + Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); - AggregateExprBuilder::new(sum_udaf(), vec![child]) - .schema(schema) - .alias("count") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + AggregateExprBuilder::new(sum_udaf(), vec![child]) + .schema(schema) + .alias("count") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } } AggExprStruct::Min(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; diff --git a/native/spark-expr/src/agg_funcs/count_not_null.rs b/native/spark-expr/src/agg_funcs/count_not_null.rs new file mode 100644 index 0000000000..4aea0b33c0 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/count_not_null.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::array::{Array, Int64Array}; +use arrow::datatypes::DataType; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, TypeSignature, Volatility, +}; + +/// CountNotNull aggregate function +/// Counts the number of non-null values in the input expression +#[derive(Debug)] +pub struct CountNotNull { + signature: Signature, +} + +impl Default for CountNotNull { + fn default() -> Self { + Self::new() + } +} + +impl CountNotNull { + pub fn new() -> Self { + Self { + signature: Signature::one_of(vec![TypeSignature::Any(1)], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for CountNotNull { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "count_not_null" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Int64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(CountNotNullAccumulator::new())) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { + Ok(Box::new(CountNotNullGroupsAccumulator::new())) + } +} + +#[derive(Debug)] +struct CountNotNullAccumulator { + count: i64, +} + +impl CountNotNullAccumulator { + fn new() -> Self { + Self { count: 0 } + } +} + +impl Accumulator for CountNotNullAccumulator { + fn update_batch(&mut self, values: &[Arc]) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + let array = &values[0]; + let non_null_count = array.len() - array.null_count(); + self.count += non_null_count as i64; + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + Ok(ScalarValue::Int64(Some(self.count))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + Ok(vec![ScalarValue::Int64(Some(self.count))]) + } + + fn merge_batch(&mut self, states: &[Arc]) -> DFResult<()> { + if states.is_empty() { + return Ok(()); + } + + let counts = states[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Expected Int64Array".to_string()))?; + + for i in 0..counts.len() { + if let Some(count) = counts.value(i).into() { + self.count += count; + } + } + Ok(()) + } +} + +#[derive(Debug)] +struct CountNotNullGroupsAccumulator { + counts: Vec, +} + +impl CountNotNullGroupsAccumulator { + fn new() -> Self { + Self { counts: vec![] } + } +} + +impl GroupsAccumulator for CountNotNullGroupsAccumulator { + fn update_batch( + &mut self, + values: &[Arc], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Resize counts if needed + if self.counts.len() < total_num_groups { + self.counts.resize(total_num_groups, 0); + } + + let array = &values[0]; + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + // Check filter if present + if let Some(filter) = opt_filter { + if !filter.value(row_idx) { + continue; + } + } + + // Check if value is not null + if !array.is_null(row_idx) { + self.counts[group_idx] += 1; + } + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult> { + let counts = emit_to.take_needed(&mut self.counts); + let result = Int64Array::from_iter_values(counts.iter().copied()); + Ok(Arc::new(result)) + } + + fn state(&mut self, emit_to: EmitTo) -> DFResult>> { + let counts = emit_to.take_needed(&mut self.counts); + let result = Int64Array::from_iter_values(counts.iter().copied()); + Ok(vec![Arc::new(result)]) + } + + fn merge_batch( + &mut self, + values: &[Arc], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Resize counts if needed + if self.counts.len() < total_num_groups { + self.counts.resize(total_num_groups, 0); + } + + let counts_array = values[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Expected Int64Array".to_string()))?; + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + // Check filter if present + if let Some(filter) = opt_filter { + if !filter.value(row_idx) { + continue; + } + } + + if let Some(count) = counts_array.value(row_idx).into() { + self.counts[group_idx] += count; + } + } + + Ok(()) + } + + fn size(&self) -> usize { + std::mem::size_of::() + self.counts.capacity() * std::mem::size_of::() + } +} diff --git a/native/spark-expr/src/agg_funcs/count_rows.rs b/native/spark-expr/src/agg_funcs/count_rows.rs new file mode 100644 index 0000000000..b0d78ef199 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/count_rows.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::array::{Array, Int64Array}; +use arrow::datatypes::DataType; +use datafusion::common::{Result as DFResult, ScalarValue}; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, TypeSignature, Volatility, +}; + +/// CountRows aggregate function +/// Counts all rows including those with null values (equivalent to COUNT(*)) +#[derive(Debug)] +pub struct CountRows { + signature: Signature, +} + +impl Default for CountRows { + fn default() -> Self { + Self::new() + } +} + +impl CountRows { + pub fn new() -> Self { + Self { + signature: Signature::one_of(vec![TypeSignature::Any(1)], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for CountRows { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "count_rows" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Int64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(CountRowsAccumulator::new())) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { + Ok(Box::new(CountRowsGroupsAccumulator::new())) + } +} + +#[derive(Debug)] +struct CountRowsAccumulator { + count: i64, +} + +impl CountRowsAccumulator { + fn new() -> Self { + Self { count: 0 } + } +} + +impl Accumulator for CountRowsAccumulator { + fn update_batch(&mut self, values: &[Arc]) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Count all rows regardless of null values + let array = &values[0]; + self.count += array.len() as i64; + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + Ok(ScalarValue::Int64(Some(self.count))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + Ok(vec![ScalarValue::Int64(Some(self.count))]) + } + + fn merge_batch(&mut self, states: &[Arc]) -> DFResult<()> { + if states.is_empty() { + return Ok(()); + } + + let counts = states[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::common::DataFusionError::Internal("Expected Int64Array".to_string()) + })?; + + for i in 0..counts.len() { + if let Some(count) = counts.value(i).into() { + self.count += count; + } + } + Ok(()) + } +} + +#[derive(Debug)] +struct CountRowsGroupsAccumulator { + counts: Vec, +} + +impl CountRowsGroupsAccumulator { + fn new() -> Self { + Self { counts: vec![] } + } +} + +impl GroupsAccumulator for CountRowsGroupsAccumulator { + fn update_batch( + &mut self, + values: &[Arc], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Resize counts if needed + if self.counts.len() < total_num_groups { + self.counts.resize(total_num_groups, 0); + } + + // Count all rows for each group, regardless of null values + for &group_idx in group_indices.iter() { + // Check filter if present + if let Some(filter) = opt_filter { + if filter.value(group_idx) { + self.counts[group_idx] += 1; + } + } else { + self.counts[group_idx] += 1; + } + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult> { + let counts = emit_to.take_needed(&mut self.counts); + let result = Int64Array::from_iter_values(counts.iter().copied()); + Ok(Arc::new(result)) + } + + fn state(&mut self, emit_to: EmitTo) -> DFResult>> { + let counts = emit_to.take_needed(&mut self.counts); + let result = Int64Array::from_iter_values(counts.iter().copied()); + Ok(vec![Arc::new(result)]) + } + + fn merge_batch( + &mut self, + values: &[Arc], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + if values.is_empty() { + return Ok(()); + } + + // Resize counts if needed + if self.counts.len() < total_num_groups { + self.counts.resize(total_num_groups, 0); + } + + let counts_array = values[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::common::DataFusionError::Internal("Expected Int64Array".to_string()) + })?; + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + // Check filter if present + if let Some(filter) = opt_filter { + if !filter.value(row_idx) { + continue; + } + } + + if let Some(count) = counts_array.value(row_idx).into() { + self.counts[group_idx] += count; + } + } + + Ok(()) + } + + fn size(&self) -> usize { + std::mem::size_of::() + self.counts.capacity() * std::mem::size_of::() + } +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 252da78890..5cb7dbcbe9 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -18,6 +18,8 @@ mod avg; mod avg_decimal; mod correlation; +mod count_not_null; +mod count_rows; mod covariance; mod stddev; mod sum_decimal; @@ -26,6 +28,8 @@ mod variance; pub use avg::Avg; pub use avg_decimal::AvgDecimal; pub use correlation::Correlation; +pub use count_not_null::CountNotNull; +pub use count_rows::CountRows; pub use covariance::Covariance; pub use stddev::Stddev; pub use sum_decimal::SumDecimal;