diff --git a/common/src/main/java/org/apache/comet/vector/CometVector.java b/common/src/main/java/org/apache/comet/vector/CometVector.java index 6be8b28669..0c6fa8f12d 100644 --- a/common/src/main/java/org/apache/comet/vector/CometVector.java +++ b/common/src/main/java/org/apache/comet/vector/CometVector.java @@ -143,66 +143,66 @@ public byte[] copyBinaryDecimal(int i, byte[] dest) { @Override public boolean getBoolean(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public byte getByte(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public short getShort(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public int getInt(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public long getLong(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } public long getLongDecimal(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public float getFloat(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public double getDouble(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public UTF8String getUTF8String(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public byte[] getBinary(int rowId) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public ColumnarArray getArray(int i) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); 
} @Override public ColumnarMap getMap(int i) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override public ColumnVector getChild(int i) { - throw new UnsupportedOperationException("Not yet supported"); + throw notImplementedException(); } @Override @@ -261,4 +261,9 @@ public static CometVector getVector( protected static CometVector getVector(ValueVector vector, boolean useDecimal128) { return getVector(vector, useDecimal128, null); } + + private UnsupportedOperationException notImplementedException() { + return new UnsupportedOperationException( + "CometVector subclass " + this.getClass().getName() + " does not implement this method"); + } } diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 6051a459e3..59222a118a 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -30,6 +30,7 @@ use crate::{ use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION}; use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf}; +use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_aggregate::min_max::min_udaf; use datafusion::functions_aggregate::sum::sum_udaf; @@ -1904,41 +1905,58 @@ impl PhysicalPlanner { match spark_expr.expr_struct.as_ref().unwrap() { AggExprStruct::Count(expr) => { assert!(!expr.children.is_empty()); - // Using `count_udaf` from Comet is exceptionally slow for some reason, so - // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))` - // https://github.com/apache/datafusion-comet/issues/744 - - let children = expr - .children - .iter() - .map(|child| self.create_expr(child, Arc::clone(&schema))) - .collect::<Result<Vec<_>, _>>()?; + if spark_expr.distinct { + let children = expr + .children + .iter() + .map(|child| self.create_expr(child, 
Arc::clone(&schema))) + .collect::<Result<Vec<_>, _>>()?; + + AggregateExprBuilder::new(count_udaf(), children) + .schema(schema) + .alias("count") + .with_ignore_nulls(false) + .with_distinct(true) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } else { + // Using `count_udaf` from Comet is exceptionally slow for some reason, so + // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))` + // https://github.com/apache/datafusion-comet/issues/744 - // create `IS NOT NULL expr` and join them with `AND` if there are multiple - let not_null_expr: Arc<dyn PhysicalExpr> = children.iter().skip(1).fold( - Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) as Arc<dyn PhysicalExpr>, - |acc, child| { - Arc::new(BinaryExpr::new( - acc, - DataFusionOperator::And, - Arc::new(IsNotNullExpr::new(Arc::clone(child))), - )) - }, - ); + let children = expr + .children + .iter() + .map(|child| self.create_expr(child, Arc::clone(&schema))) + .collect::<Result<Vec<_>, _>>()?; + + // create `IS NOT NULL expr` and join them with `AND` if there are multiple + let not_null_expr: Arc<dyn PhysicalExpr> = children.iter().skip(1).fold( + Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) + as Arc<dyn PhysicalExpr>, + |acc, child| { + Arc::new(BinaryExpr::new( + acc, + DataFusionOperator::And, + Arc::new(IsNotNullExpr::new(Arc::clone(child))), + )) + }, + ); - let child = Arc::new(IfExpr::new( - not_null_expr, - Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), - Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), - )); + let child = Arc::new(IfExpr::new( + not_null_expr, + Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); - AggregateExprBuilder::new(sum_udaf(), vec![child]) - .schema(schema) - .alias("count") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + AggregateExprBuilder::new(sum_udaf(), vec![child]) + .schema(schema) + .alias("count") + .with_ignore_nulls(false) + .with_distinct(false) + 
.build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } } AggExprStruct::Min(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 04d9376ac6..9fedbf1c28 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -90,6 +90,7 @@ message Expr { } message AggExpr { + bool distinct = 1; oneof expr_struct { Count count = 2; Sum sum = 3; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 1606690eb0..11201539a6 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -510,19 +510,17 @@ object QueryPlanSerde extends Logging with CometExprShim { binding: Boolean, conf: SQLConf): Option[AggExpr] = { - if (aggExpr.isDistinct) { - // https://github.com/apache/datafusion-comet/issues/1260 - withInfo(aggExpr, s"distinct aggregate not supported: $aggExpr") - return None - } - val fn = aggExpr.aggregateFunction val cometExpr = aggrSerdeMap.get(fn.getClass) cometExpr match { case Some(handler) => - handler - .asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]] - .convert(aggExpr, fn, inputs, binding, conf) + val aggSerde = handler.asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]] + if (aggExpr.isDistinct && !aggSerde.supportsDistinct) { + // https://github.com/apache/datafusion-comet/issues/1260 + withInfo(aggExpr, s"distinct aggregate not supported: $fn") + return None + } + aggSerde.convert(aggExpr, fn, inputs, binding, conf) case _ => withInfo( aggExpr, @@ -2155,6 +2153,8 @@ trait CometExpressionSerde[T <: Expression] { */ trait CometAggregateExpressionSerde[T <: AggregateFunction] { + def supportsDistinct: Boolean = false + /** * Convert a Spark expression into a protocol buffer 
representation that can be passed into * native code. diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 51c8951289..7f80435b55 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -103,12 +103,19 @@ object CometMax extends CometAggregateExpressionSerde[Max] { } object CometCount extends CometAggregateExpressionSerde[Count] { + + override def supportsDistinct: Boolean = true + override def convert( aggExpr: AggregateExpression, expr: Count, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + if (expr.children.length > 1) { + withInfo(aggExpr, "COUNT only supports a single argument") + return None + } val exprChildren = expr.children.map(exprToProto(_, inputs, binding)) if (exprChildren.forall(_.isDefined)) { val builder = ExprOuterClass.Count.newBuilder() @@ -116,6 +123,7 @@ object CometCount extends CometAggregateExpressionSerde[Count] { Some( ExprOuterClass.AggExpr .newBuilder() + .setDistinct(aggExpr.isDistinct) .setCount(builder) .build()) } else {