From 6ee6defef6dff91d06ed80feb0b7507539507afe Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 06:51:23 -0700 Subject: [PATCH 01/16] filter --- .../org/apache/comet/serde/CometFilter.scala | 51 +++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 18 ++----- 2 files changed, 55 insertions(+), 14 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometFilter.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/CometFilter.scala b/spark/src/main/scala/org/apache/comet/serde/CometFilter.scala new file mode 100644 index 0000000000..1638750b5f --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometFilter.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.execution.FilterExec + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +object CometFilter extends CometOperatorSerde[FilterExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_FILTER_ENABLED) + + override def convert( + op: FilterExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + val cond = exprToProto(op.condition, op.child.output) + + if (cond.isDefined && childOp.nonEmpty) { + val filterBuilder = OperatorOuterClass.Filter + .newBuilder() + .setPredicate(cond.get) + Some(builder.setFilter(filterBuilder).build()) + } else { + withInfo(op, op.condition, op.child) + None + } + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 63e18c145a..8ffec22d15 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -66,7 +66,10 @@ object QueryPlanSerde extends Logging with CometExprShim { * Mapping of Spark operator class to Comet operator handler. 
*/ private val opSerdeMap: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] = - Map(classOf[ProjectExec] -> CometProject, classOf[SortExec] -> CometSort) + Map( + classOf[ProjectExec] -> CometProject, + classOf[FilterExec] -> CometFilter, + classOf[SortExec] -> CometSort) private val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[ArrayAppend] -> CometArrayAppend, @@ -1065,19 +1068,6 @@ object QueryPlanSerde extends Logging with CometExprShim { None } - case FilterExec(condition, child) if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) => - val cond = exprToProto(condition, child.output) - - if (cond.isDefined && childOp.nonEmpty) { - val filterBuilder = OperatorOuterClass.Filter - .newBuilder() - .setPredicate(cond.get) - Some(builder.setFilter(filterBuilder).build()) - } else { - withInfo(op, condition, child) - None - } - case LocalLimitExec(limit, _) if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) => if (childOp.nonEmpty) { // LocalLimit doesn't use offset, but it shares same operator serde class. From 3a80a1f1bc7c953b16fd9f07a31601ffc1f4b788 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 06:56:22 -0700 Subject: [PATCH 02/16] limit --- .../apache/comet/serde/CometGlobalLimit.scala | 49 ++++++++++++++++++ .../apache/comet/serde/CometLocalLimit.scala | 50 +++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 29 +---------- 3 files changed, 101 insertions(+), 27 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala b/spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala new file mode 100644 index 0000000000..774e1ad77e --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.execution.GlobalLimitExec + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.Operator + +object CometGlobalLimit extends CometOperatorSerde[GlobalLimitExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED) + + override def convert( + op: GlobalLimitExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (childOp.nonEmpty) { + val limitBuilder = OperatorOuterClass.Limit.newBuilder() + + limitBuilder.setLimit(op.limit).setOffset(op.offset) + + Some(builder.setLimit(limitBuilder).build()) + } else { + withInfo(op, "No child operator") + None + } + + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala b/spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala new file mode 100644 index 0000000000..1347b12907 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.execution.LocalLimitExec + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.Operator + +object CometLocalLimit extends CometOperatorSerde[LocalLimitExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED) + + override def convert( + op: LocalLimitExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (childOp.nonEmpty) { + // LocalLimit doesn't use offset, but it shares same operator serde class. + // Just set it to zero. 
+ val limitBuilder = OperatorOuterClass.Limit + .newBuilder() + .setLimit(op.limit) + .setOffset(0) + Some(builder.setLimit(limitBuilder).build()) + } else { + withInfo(op, "No child operator") + None + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8ffec22d15..7a4b457cdd 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -69,6 +69,8 @@ object QueryPlanSerde extends Logging with CometExprShim { Map( classOf[ProjectExec] -> CometProject, classOf[FilterExec] -> CometFilter, + classOf[LocalLimitExec] -> CometLocalLimit, + classOf[GlobalLimitExec] -> CometGlobalLimit, classOf[SortExec] -> CometSort) private val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( @@ -1068,33 +1070,6 @@ object QueryPlanSerde extends Logging with CometExprShim { None } - case LocalLimitExec(limit, _) if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) => - if (childOp.nonEmpty) { - // LocalLimit doesn't use offset, but it shares same operator serde class. - // Just set it to zero. - val limitBuilder = OperatorOuterClass.Limit - .newBuilder() - .setLimit(limit) - .setOffset(0) - Some(builder.setLimit(limitBuilder).build()) - } else { - withInfo(op, "No child operator") - None - } - - case globalLimitExec: GlobalLimitExec - if CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED.get(conf) => - if (childOp.nonEmpty) { - val limitBuilder = OperatorOuterClass.Limit.newBuilder() - - limitBuilder.setLimit(globalLimitExec.limit).setOffset(globalLimitExec.offset) - - Some(builder.setLimit(limitBuilder).build()) - } else { - withInfo(op, "No child operator") - None - } - case ExpandExec(projections, _, child) if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) => var allProjExprs: Seq[Expression] = Seq() val projExprs = projections.flatMap(_.map(e => { From 51e0aa2b206d7f13aba4dc3fc1bf3b7e62d1fa9d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 07:03:03 -0700 Subject: [PATCH 03/16] hash join --- .../apache/comet/serde/CometHashJoin.scala | 102 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 67 +----------- 2 files changed, 106 insertions(+), 63 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala new file mode 100644 index 0000000000..94fd880f8f --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec} + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.{BuildSide, JoinType, Operator} +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +object CometHashJoin extends CometOperatorSerde[HashJoin] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED) + + override def convert( + join: HashJoin, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + // `HashJoin` has only two implementations in Spark, but we check the type of the join to + // make sure we are handling the correct join type. + if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(join.conf) && + join.isInstanceOf[ShuffledHashJoinExec]) && + !(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(join.conf) && + join.isInstanceOf[BroadcastHashJoinExec])) { + withInfo(join, s"Invalid hash join type ${join.nodeName}") + return None + } + + if (join.buildSide == BuildRight && join.joinType == LeftAnti) { + // https://github.com/apache/datafusion-comet/issues/457 + withInfo(join, "BuildRight with LeftAnti is not supported") + return None + } + + val condition = join.condition.map { cond => + val condProto = exprToProto(cond, join.left.output ++ join.right.output) + if (condProto.isEmpty) { + withInfo(join, cond) + return None + } + condProto.get + } + + val joinType = join.joinType match { + case Inner => JoinType.Inner + case LeftOuter => JoinType.LeftOuter + case RightOuter => JoinType.RightOuter + case FullOuter => JoinType.FullOuter + case LeftSemi => JoinType.LeftSemi + case LeftAnti => JoinType.LeftAnti + case _ => + // Spark doesn't support other join types + withInfo(join, s"Unsupported join type ${join.joinType}") + return None + } + + val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) + val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) + + if (leftKeys.forall(_.isDefined) && + rightKeys.forall(_.isDefined) && + childOp.nonEmpty) { + val joinBuilder = OperatorOuterClass.HashJoin + .newBuilder() + .setJoinType(joinType) + .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) + .addAllRightJoinKeys(rightKeys.map(_.get).asJava) + .setBuildSide( + if (join.buildSide == BuildLeft) BuildSide.BuildLeft else BuildSide.BuildRight) + condition.foreach(joinBuilder.setCondition) + Some(builder.setHashJoin(joinBuilder).build()) + } else { + val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys + withInfo(join, allExprs: _*) + None + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7a4b457cdd..e51f69adb0 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import 
org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero} +import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{HashJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -50,7 +50,7 @@ import org.apache.comet.expressions._ import org.apache.comet.objectstore.NativeConfig import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc} -import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator} +import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} import org.apache.comet.serde.Types.{DataType => ProtoDataType} import org.apache.comet.serde.Types.DataType._ @@ -71,6 +71,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[FilterExec] -> CometFilter, classOf[LocalLimitExec] -> CometLocalLimit, classOf[GlobalLimitExec] -> CometGlobalLimit, + classOf[HashJoin] -> CometHashJoin, classOf[SortExec] -> CometSort) private val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( @@ -1259,66 +1260,6 @@ object QueryPlanSerde extends Logging with CometExprShim { } } - case join: HashJoin => - // `HashJoin` has only two implementations in Spark, but we check the type of the join to - // make sure we are handling the correct join type. 
- if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) && - join.isInstanceOf[ShuffledHashJoinExec]) && - !(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) && - join.isInstanceOf[BroadcastHashJoinExec])) { - withInfo(join, s"Invalid hash join type ${join.nodeName}") - return None - } - - if (join.buildSide == BuildRight && join.joinType == LeftAnti) { - // https://github.com/apache/datafusion-comet/issues/457 - withInfo(join, "BuildRight with LeftAnti is not supported") - return None - } - - val condition = join.condition.map { cond => - val condProto = exprToProto(cond, join.left.output ++ join.right.output) - if (condProto.isEmpty) { - withInfo(join, cond) - return None - } - condProto.get - } - - val joinType = join.joinType match { - case Inner => JoinType.Inner - case LeftOuter => JoinType.LeftOuter - case RightOuter => JoinType.RightOuter - case FullOuter => JoinType.FullOuter - case LeftSemi => JoinType.LeftSemi - case LeftAnti => JoinType.LeftAnti - case _ => - // Spark doesn't support other join types - withInfo(join, s"Unsupported join type ${join.joinType}") - return None - } - - val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) - val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) - - if (leftKeys.forall(_.isDefined) && - rightKeys.forall(_.isDefined) && - childOp.nonEmpty) { - val joinBuilder = OperatorOuterClass.HashJoin - .newBuilder() - .setJoinType(joinType) - .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) - .addAllRightJoinKeys(rightKeys.map(_.get).asJava) - .setBuildSide( - if (join.buildSide == BuildLeft) BuildSide.BuildLeft else BuildSide.BuildRight) - condition.foreach(joinBuilder.setCondition) - Some(builder.setHashJoin(joinBuilder).build()) - } else { - val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys - withInfo(join, allExprs: _*) - None - } - case join: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) => // `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec. def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { From 8c0338e764a8a2a5786aa89b10841d29ebdb9dae Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 07:07:26 -0700 Subject: [PATCH 04/16] sort merge join --- .../comet/serde/CometSortMergeJoin.scala | 144 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 109 +------------ 2 files changed, 146 insertions(+), 107 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala b/spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala new file mode 100644 index 0000000000..5f926f06e8 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, ExpressionSet, SortOrder} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType} + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.{JoinType, Operator} +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +object CometSortMergeJoin extends CometOperatorSerde[SortMergeJoinExec] { + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED) + + override def convert( + join: SortMergeJoinExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + // `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec. + def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + keys.map(SortOrder(_, Ascending)) + } + + def getKeyOrdering( + keys: Seq[Expression], + childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = { + val requiredOrdering = requiredOrders(keys) + if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) { + keys.zip(childOutputOrdering).map { case (key, childOrder) => + val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key + SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq) + } + } else { + requiredOrdering + } + } + + if (join.condition.isDefined && + !CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED + .get(join.conf)) { + withInfo( + join, + s"${CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key} is not enabled", + join.condition.get) + return None + } + + val condition = join.condition.map { cond => + val condProto = exprToProto(cond, join.left.output ++ join.right.output) + if (condProto.isEmpty) { + withInfo(join, cond) + return None + } + condProto.get + } + + val joinType = join.joinType match { + case Inner => JoinType.Inner + case LeftOuter => JoinType.LeftOuter + case RightOuter => JoinType.RightOuter + case FullOuter => JoinType.FullOuter + case LeftSemi => JoinType.LeftSemi + case LeftAnti => JoinType.LeftAnti + case _ => + // Spark doesn't support other join types + withInfo(join, s"Unsupported join type ${join.joinType}") + return None + } + + // Checks if the join keys are supported by DataFusion SortMergeJoin. 
+ val errorMsgs = join.leftKeys.flatMap { key => + if (!supportedSortMergeJoinEqualType(key.dataType)) { + Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}") + } else { + None + } + } + + if (errorMsgs.nonEmpty) { + withInfo(join, errorMsgs.flatten.mkString("\n")) + return None + } + + val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) + val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) + + val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering) + .map(exprToProto(_, join.left.output)) + + if (sortOptions.forall(_.isDefined) && + leftKeys.forall(_.isDefined) && + rightKeys.forall(_.isDefined) && + childOp.nonEmpty) { + val joinBuilder = OperatorOuterClass.SortMergeJoin + .newBuilder() + .setJoinType(joinType) + .addAllSortOptions(sortOptions.map(_.get).asJava) + .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) + .addAllRightJoinKeys(rightKeys.map(_.get).asJava) + condition.map(joinBuilder.setCondition) + Some(builder.setSortMergeJoin(joinBuilder).build()) + } else { + val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys + withInfo(join, allExprs: _*) + None + } + + } + + /** + * Returns true if given datatype is supported as a key in DataFusion sort merge join. + */ + private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | + _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType => + true + case TimestampNTZType => true + case _ => false + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e51f69adb0..62ebd96c75 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues import org.apache.spark.sql.comet._ @@ -50,7 +49,7 @@ import org.apache.comet.expressions._ import org.apache.comet.objectstore.NativeConfig import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc} -import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} +import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} import org.apache.comet.serde.Types.{DataType => ProtoDataType} import org.apache.comet.serde.Types.DataType._ @@ -72,6 +71,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[LocalLimitExec] -> CometLocalLimit, classOf[GlobalLimitExec] -> CometGlobalLimit, classOf[HashJoin] -> CometHashJoin, + classOf[SortMergeJoinExec] -> CometSortMergeJoin, classOf[SortExec] -> CometSort) private val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( @@ -917,17 +917,6 @@ object QueryPlanSerde extends Logging with CometExprShim { 
Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) } - /** - * Returns true if given datatype is supported as a key in DataFusion sort merge join. - */ - private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match { - case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | - _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType => - true - case TimestampNTZType => true - case _ => false - } - /** * Convert a Spark plan operator to a protobuf Comet operator. * @@ -1260,100 +1249,6 @@ object QueryPlanSerde extends Logging with CometExprShim { } } - case join: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) => - // `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec. - def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { - keys.map(SortOrder(_, Ascending)) - } - - def getKeyOrdering( - keys: Seq[Expression], - childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = { - val requiredOrdering = requiredOrders(keys) - if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) { - keys.zip(childOutputOrdering).map { case (key, childOrder) => - val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key - SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq) - } - } else { - requiredOrdering - } - } - - if (join.condition.isDefined && - !CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED - .get(conf)) { - withInfo( - join, - s"${CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key} is not enabled", - join.condition.get) - return None - } - - val condition = join.condition.map { cond => - val condProto = exprToProto(cond, join.left.output ++ join.right.output) - if (condProto.isEmpty) { - withInfo(join, cond) - return None - } - condProto.get - } - - val joinType = join.joinType match { - case Inner => JoinType.Inner - case LeftOuter => JoinType.LeftOuter - case RightOuter => JoinType.RightOuter - case FullOuter => JoinType.FullOuter - case LeftSemi => JoinType.LeftSemi - case LeftAnti => JoinType.LeftAnti - case _ => - // Spark doesn't support other join types - withInfo(op, s"Unsupported join type ${join.joinType}") - return None - } - - // Checks if the join keys are supported by DataFusion SortMergeJoin. 
- val errorMsgs = join.leftKeys.flatMap { key => - if (!supportedSortMergeJoinEqualType(key.dataType)) { - Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}") - } else { - None - } - } - - if (errorMsgs.nonEmpty) { - withInfo(op, errorMsgs.flatten.mkString("\n")) - return None - } - - val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) - val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) - - val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering) - .map(exprToProto(_, join.left.output)) - - if (sortOptions.forall(_.isDefined) && - leftKeys.forall(_.isDefined) && - rightKeys.forall(_.isDefined) && - childOp.nonEmpty) { - val joinBuilder = OperatorOuterClass.SortMergeJoin - .newBuilder() - .setJoinType(joinType) - .addAllSortOptions(sortOptions.map(_.get).asJava) - .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) - .addAllRightJoinKeys(rightKeys.map(_.get).asJava) - condition.map(joinBuilder.setCondition) - Some(builder.setSortMergeJoin(joinBuilder).build()) - } else { - val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys - withInfo(join, allExprs: _*) - None - } - - case join: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) => - withInfo(join, "SortMergeJoin is not enabled") - None - case op if isCometSink(op) => val supportedTypes = op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) From e30e7767f3bca008c07a635cb47f4d2a6d735e1c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 07:26:31 -0700 Subject: [PATCH 05/16] save --- .../apache/comet/serde/CometNativeScan.scala | 218 ++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 319 +++++------------- 2 files changed, 298 insertions(+), 239 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala new file mode 100644 index 0000000000..476313a9d1 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues +import org.apache.spark.sql.comet.CometScanExec +import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{StructField, StructType} + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.objectstore.NativeConfig +import org.apache.comet.parquet.CometParquetUtils +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} + +object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = None + + override def convert( + scan: CometScanExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder() + nativeScanBuilder.setSource(scan.simpleStringWithNodeId()) + + val scanTypes = scan.output.flatten { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == scan.output.length) { + nativeScanBuilder.addAllFields(scanTypes.asJava) + + // Sink operators don't have children + builder.clearChildren() + + if (scan.conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED) && + CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.get(scan.conf)) { + + val dataFilters = new ListBuffer[Expr]() + for (filter <- scan.dataFilters) { + exprToProto(filter, scan.output) match { + case Some(proto) => dataFilters += proto + case _ => + logWarning(s"Unsupported data filter $filter") + } + } + nativeScanBuilder.addAllDataFilters(dataFilters.asJava) + } + + val possibleDefaultValues = getExistenceDefaultValues(scan.requiredSchema) + if (possibleDefaultValues.exists(_ != null)) { + // Our schema has default values. Serialize two lists, one with the default values + // and another with the indexes in the schema so the native side can map missing + // columns to these default values. + val (defaultValues, indexes) = possibleDefaultValues.zipWithIndex + .filter { case (expr, _) => expr != null } + .map { case (expr, index) => + // ResolveDefaultColumnsUtil.getExistenceDefaultValues has evaluated these + // expressions and they should now just be literals. + (Literal(expr), index.toLong.asInstanceOf[java.lang.Long]) + } + .unzip + nativeScanBuilder.addAllDefaultValues( + defaultValues.flatMap(exprToProto(_, scan.output)).toIterable.asJava) + nativeScanBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava) + } + + // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD. 
+ var firstPartition: Option[PartitionedFile] = None + scan.inputRDD match { + case rdd: DataSourceRDD => + val partitions = rdd.partitions + partitions.foreach(p => { + val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions + inputPartitions.foreach(partition => { + if (firstPartition.isEmpty) { + firstPartition = partition.asInstanceOf[FilePartition].files.headOption + } + partition2Proto( + partition.asInstanceOf[FilePartition], + nativeScanBuilder, + scan.relation.partitionSchema) + }) + }) + case rdd: FileScanRDD => + rdd.filePartitions.foreach(partition => { + if (firstPartition.isEmpty) { + firstPartition = partition.files.headOption + } + partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema) + }) + case _ => + } + + val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields) + val requiredSchema = schema2Proto(scan.requiredSchema.fields) + val dataSchema = schema2Proto(scan.relation.dataSchema.fields) + + val dataSchemaIndexes = scan.requiredSchema.fields.map(field => { + scan.relation.dataSchema.fieldIndex(field.name) + }) + val partitionSchemaIndexes = Array + .range( + scan.relation.dataSchema.fields.length, + scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length) + + val projectionVector = (dataSchemaIndexes ++ partitionSchemaIndexes).map(idx => + idx.toLong.asInstanceOf[java.lang.Long]) + + nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava) + + // In `CometScanRule`, we ensure partitionSchema is supported. + assert(partitionSchema.length == scan.relation.partitionSchema.fields.length) + + nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava) + nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava) + nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava) + nativeScanBuilder.setSessionTimezone(scan.conf.getConfString("spark.sql.session.timeZone")) + nativeScanBuilder.setCaseSensitive(scan.conf.getConf[Boolean](SQLConf.CASE_SENSITIVE)) + + // Collect S3/cloud storage configurations + val hadoopConf = scan.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scan.relation.options) + + nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf)) + + firstPartition.foreach { partitionFile => + val objectStoreOptions = + NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri) + objectStoreOptions.foreach { case (key, value) => + nativeScanBuilder.putObjectStoreOptions(key, value) + } + } + + Some(builder.setNativeScan(nativeScanBuilder).build()) + + } else { + // There are unsupported scan type + withInfo( + scan, + s"unsupported Comet operator: ${scan.nodeName}, due to unsupported data types above") + None + } + + } + + private def schema2Proto( + fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = { + val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder() + fields.map(field => { + fieldBuilder.setName(field.name) + fieldBuilder.setDataType(serializeDataType(field.dataType).get) + fieldBuilder.setNullable(field.nullable) + fieldBuilder.build() + }) + } + + private def partition2Proto( + partition: FilePartition, + nativeScanBuilder: OperatorOuterClass.NativeScan.Builder, + partitionSchema: StructType): Unit = { + val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder() + partition.files.foreach(file => { + // Process the partition values + val partitionValues = file.partitionValues + assert(partitionValues.numFields == 
partitionSchema.length) + val partitionVals = + partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value, i) => + val attr = partitionSchema(i) + val valueProto = exprToProto(Literal(value, attr.dataType), Seq.empty) + // In `CometScanRule`, we have already checked that all partition values are + // supported. So, we can safely use `get` here. + assert( + valueProto.isDefined, + s"Unsupported partition value: $value, type: ${attr.dataType}") + valueProto.get + } + + val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder() + partitionVals.foreach(fileBuilder.addPartitionValues) + fileBuilder + .setFilePath(file.filePath.toString) + .setStart(file.start) + .setLength(file.length) + .setFileSize(file.fileSize) + partitionBuilder.addPartitionedFile(fileBuilder.build()) + }) + nativeScanBuilder.addFilePartitions(partitionBuilder.build()) + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 62ebd96c75..d2a767a797 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -19,7 +19,6 @@ package org.apache.comet.serde -import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters._ import org.apache.spark.internal.Logging @@ -28,15 +27,12 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} -import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{HashJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec @@ -46,8 +42,6 @@ import org.apache.spark.sql.types._ import org.apache.comet.{CometConf, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ -import org.apache.comet.objectstore.NativeConfig -import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} @@ -938,127 +932,7 @@ object QueryPlanSerde extends Logging with CometExprShim { // Fully native scan for V1 case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => - val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder() - nativeScanBuilder.setSource(op.simpleStringWithNodeId()) - - val scanTypes = op.output.flatten { attr => - 
serializeDataType(attr.dataType) - } - - if (scanTypes.length == op.output.length) { - nativeScanBuilder.addAllFields(scanTypes.asJava) - - // Sink operators don't have children - builder.clearChildren() - - if (conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED) && - CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.get(conf)) { - - val dataFilters = new ListBuffer[Expr]() - for (filter <- scan.dataFilters) { - exprToProto(filter, scan.output) match { - case Some(proto) => dataFilters += proto - case _ => - logWarning(s"Unsupported data filter $filter") - } - } - nativeScanBuilder.addAllDataFilters(dataFilters.asJava) - } - - val possibleDefaultValues = getExistenceDefaultValues(scan.requiredSchema) - if (possibleDefaultValues.exists(_ != null)) { - // Our schema has default values. Serialize two lists, one with the default values - // and another with the indexes in the schema so the native side can map missing - // columns to these default values. - val (defaultValues, indexes) = possibleDefaultValues.zipWithIndex - .filter { case (expr, _) => expr != null } - .map { case (expr, index) => - // ResolveDefaultColumnsUtil.getExistenceDefaultValues has evaluated these - // expressions and they should now just be literals. - (Literal(expr), index.toLong.asInstanceOf[java.lang.Long]) - } - .unzip - nativeScanBuilder.addAllDefaultValues( - defaultValues.flatMap(exprToProto(_, scan.output)).toIterable.asJava) - nativeScanBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava) - } - - // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD. - var firstPartition: Option[PartitionedFile] = None - scan.inputRDD match { - case rdd: DataSourceRDD => - val partitions = rdd.partitions - partitions.foreach(p => { - val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions - inputPartitions.foreach(partition => { - if (firstPartition.isEmpty) { - firstPartition = partition.asInstanceOf[FilePartition].files.headOption - } - partition2Proto( - partition.asInstanceOf[FilePartition], - nativeScanBuilder, - scan.relation.partitionSchema) - }) - }) - case rdd: FileScanRDD => - rdd.filePartitions.foreach(partition => { - if (firstPartition.isEmpty) { - firstPartition = partition.files.headOption - } - partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema) - }) - case _ => - } - - val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields) - val requiredSchema = schema2Proto(scan.requiredSchema.fields) - val dataSchema = schema2Proto(scan.relation.dataSchema.fields) - - val dataSchemaIndexes = scan.requiredSchema.fields.map(field => { - scan.relation.dataSchema.fieldIndex(field.name) - }) - val partitionSchemaIndexes = Array - .range( - scan.relation.dataSchema.fields.length, - scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length) - - val projectionVector = (dataSchemaIndexes ++ partitionSchemaIndexes).map(idx => - idx.toLong.asInstanceOf[java.lang.Long]) - - nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava) - - // In `CometScanRule`, we ensure partitionSchema is supported. 
- assert(partitionSchema.length == scan.relation.partitionSchema.fields.length) - - nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava) - nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava) - nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava) - nativeScanBuilder.setSessionTimezone(conf.getConfString("spark.sql.session.timeZone")) - nativeScanBuilder.setCaseSensitive(conf.getConf[Boolean](SQLConf.CASE_SENSITIVE)) - - // Collect S3/cloud storage configurations - val hadoopConf = scan.relation.sparkSession.sessionState - .newHadoopConfWithOptions(scan.relation.options) - - nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf)) - - firstPartition.foreach { partitionFile => - val objectStoreOptions = - NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri) - objectStoreOptions.foreach { case (key, value) => - nativeScanBuilder.putObjectStoreOptions(key, value) - } - } - - Some(builder.setNativeScan(nativeScanBuilder).build()) - - } else { - // There are unsupported scan type - withInfo( - op, - s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above") - None - } + CometNativeScan.convert(scan, builder, childOp: _*) case ExpandExec(projections, _, child) if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) => var allProjExprs: Seq[Expression] = Seq() @@ -1079,7 +953,7 @@ object QueryPlanSerde extends Logging with CometExprShim { } case WindowExec(windowExpression, partitionSpec, orderSpec, child) - if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) => + if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) => val output = child.output val winExprs: Array[WindowExpression] = windowExpression.flatMap { expr => @@ -1123,9 +997,9 @@ object QueryPlanSerde extends Logging with CometExprShim { } case aggregate: BaseAggregateExec - if (aggregate.isInstanceOf[HashAggregateExec] || - aggregate.isInstanceOf[ObjectHashAggregateExec]) && - CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => + if (aggregate.isInstanceOf[HashAggregateExec] || + aggregate.isInstanceOf[ObjectHashAggregateExec]) && + CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => val groupingExpressions = aggregate.groupingExpressions val aggregateExpressions = aggregate.aggregateExpressions val aggregateAttributes = aggregate.aggregateAttributes @@ -1144,10 +1018,10 @@ object QueryPlanSerde extends Logging with CometExprShim { } if (groupingExpressions.exists(expr => - expr.dataType match { - case _: MapType => true - case _ => false - })) { + expr.dataType match { + case _: MapType => true + case _ => false + })) { withInfo(op, "Grouping on map types is not supported") return None } @@ -1249,60 +1123,6 @@ object QueryPlanSerde extends Logging with CometExprShim { } } - case op if isCometSink(op) => - val supportedTypes = - op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) - - if (!supportedTypes) { - withInfo(op, "Unsupported data type") - return None - } - - // These operators are source of Comet native execution chain - val scanBuilder = OperatorOuterClass.Scan.newBuilder() - val source = op.simpleStringWithNodeId() - if (source.isEmpty) { - scanBuilder.setSource(op.getClass.getSimpleName) - } else { - scanBuilder.setSource(source) - } - - val ffiSafe = op match { - case _ if isExchangeSink(op) => - // Source of broadcast exchange batches is ArrowStreamReader - // Source of shuffle exchange batches is NativeBatchDecoderIterator - true - case scan: CometScanExec if scan.scanImpl == 
CometConf.SCAN_NATIVE_COMET => - // native_comet scan reuses mutable buffers - false - case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT => - // native_iceberg_compat scan reuses mutable buffers for constant columns - // https://github.com/apache/datafusion-comet/issues/2152 - false - case _ => - false - } - scanBuilder.setArrowFfiSafe(ffiSafe) - - val scanTypes = op.output.flatten { attr => - serializeDataType(attr.dataType) - } - - if (scanTypes.length == op.output.length) { - scanBuilder.addAllFields(scanTypes.asJava) - - // Sink operators don't have children - builder.clearChildren() - - Some(builder.setScan(scanBuilder).build()) - } else { - // There are unsupported scan type - withInfo( - op, - s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above") - None - } - case op => opSerdeMap.get(op.getClass) match { case Some(handler) => @@ -1312,22 +1132,86 @@ object QueryPlanSerde extends Logging with CometExprShim { op, s"Native support for operator ${op.getClass.getSimpleName} is disabled. " + s"Set ${enabledConfig.key}=true to enable it.") - return None + if (isCometSink(op)) { + return cometSink(op) + } else { + return None + } } } handler.asInstanceOf[CometOperatorSerde[SparkPlan]].convert(op, builder, childOp: _*) case _ => - // Emit warning if: - // 1. it is not Spark shuffle operator, which is handled separately - // 2. it is not a Comet operator - if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) { - withInfo(op, s"unsupported Spark operator: ${op.nodeName}") + if (isCometSink(op)) { + cometSink(op) + } else { + // Emit warning if: + // 1. it is not Spark shuffle operator, which is handled separately + // 2. it is not a Comet operator + if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) { + withInfo(op, s"unsupported Spark operator: ${op.nodeName}") + } + None } - None } } } + def cometSink(op: SparkPlan): Option[Operator] = { + val supportedTypes = + op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + + if (!supportedTypes) { + withInfo(op, "Unsupported data type") + return None + } + + // These operators are source of Comet native execution chain + val scanBuilder = OperatorOuterClass.Scan.newBuilder() + val source = op.simpleStringWithNodeId() + if (source.isEmpty) { + scanBuilder.setSource(op.getClass.getSimpleName) + } else { + scanBuilder.setSource(source) + } + + val ffiSafe = op match { + case _ if isExchangeSink(op) => + // Source of broadcast exchange batches is ArrowStreamReader + // Source of shuffle exchange batches is NativeBatchDecoderIterator + true + case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_COMET => + // native_comet scan reuses mutable buffers + false + case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT => + // native_iceberg_compat scan reuses mutable buffers for constant columns + // https://github.com/apache/datafusion-comet/issues/2152 + false + case _ => + false + } + scanBuilder.setArrowFfiSafe(ffiSafe) + + val scanTypes = op.output.flatten { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + scanBuilder.addAllFields(scanTypes.asJava) + + // Sink operators don't have children + builder.clearChildren() + + Some(builder.setScan(scanBuilder).build()) + } else { + // There are unsupported scan type + withInfo( + op, + s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above") + None + } + + } + /** * 
Whether the input Spark operator `op` can be considered as a Comet sink, i.e., the start of * native execution. If it is true, we'll wrap `op` with `CometScanWrapper` or @@ -1458,49 +1342,6 @@ object QueryPlanSerde extends Logging with CometExprShim { true } - private def schema2Proto( - fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = { - val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder() - fields.map(field => { - fieldBuilder.setName(field.name) - fieldBuilder.setDataType(serializeDataType(field.dataType).get) - fieldBuilder.setNullable(field.nullable) - fieldBuilder.build() - }) - } - - private def partition2Proto( - partition: FilePartition, - nativeScanBuilder: OperatorOuterClass.NativeScan.Builder, - partitionSchema: StructType): Unit = { - val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder() - partition.files.foreach(file => { - // Process the partition values - val partitionValues = file.partitionValues - assert(partitionValues.numFields == partitionSchema.length) - val partitionVals = - partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value, i) => - val attr = partitionSchema(i) - val valueProto = exprToProto(Literal(value, attr.dataType), Seq.empty) - // In `CometScanRule`, we have already checked that all partition values are - // supported. So, we can safely use `get` here. - assert( - valueProto.isDefined, - s"Unsupported partition value: $value, type: ${attr.dataType}") - valueProto.get - } - - val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder() - partitionVals.foreach(fileBuilder.addPartitionValues) - fileBuilder - .setFilePath(file.filePath.toString) - .setStart(file.start) - .setLength(file.length) - .setFileSize(file.fileSize) - partitionBuilder.addPartitionedFile(fileBuilder.build()) - }) - nativeScanBuilder.addFilePartitions(partitionBuilder.build()) - } } sealed trait SupportLevel From ac30640b052ae3af0fe7a4bb12f8b84a4432f760 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 07:37:49 -0700 Subject: [PATCH 06/16] agg --- .../apache/comet/serde/CometAggregate.scala | 168 ++++++++++++++++++ .../org/apache/comet/serde/CometExpand.scala | 60 +++++++ ...ometWindowExec.scala => CometWindow.scala} | 2 +- .../apache/comet/serde/QueryPlanSerde.scala | 146 +-------------- 4 files changed, 233 insertions(+), 143 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometExpand.scala rename spark/src/main/scala/org/apache/comet/serde/{CometWindowExec.scala => CometWindow.scala} (98%) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala new file mode 100644 index 0000000000..9b577d5a91 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.types.MapType + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} +import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto} + +object CometAggregate extends CometOperatorSerde[BaseAggregateExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_AGGREGATE_ENABLED) + + override def convert( + aggregate: BaseAggregateExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + val groupingExpressions = aggregate.groupingExpressions + val aggregateExpressions = aggregate.aggregateExpressions + val aggregateAttributes = aggregate.aggregateAttributes + val resultExpressions = aggregate.resultExpressions + val child = aggregate.child + + if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) { + withInfo(aggregate, "No group by or aggregation") + return None + } + + // Aggregate expressions with filter are not supported yet. + if (aggregateExpressions.exists(_.filter.isDefined)) { + withInfo(aggregate, "Aggregate expression with filter is not supported") + return None + } + + if (groupingExpressions.exists(expr => + expr.dataType match { + case _: MapType => true + case _ => false + })) { + withInfo(aggregate, "Grouping on map types is not supported") + return None + } + + val groupingExprsWithInput = + groupingExpressions.map(expr => expr.name -> exprToProto(expr, child.output)) + + val emptyExprs = groupingExprsWithInput.collect { + case (expr, proto) if proto.isEmpty => expr + } + + if (emptyExprs.nonEmpty) { + withInfo(aggregate, s"Unsupported group expressions: ${emptyExprs.mkString(", ")}") + return None + } + + val groupingExprs = groupingExprsWithInput.map(_._2) + + // In some of the cases, the aggregateExpressions could be empty. + // For example, if the aggregate functions only have group by or if the aggregate + // functions only have distinct aggregate functions: + // + // SELECT COUNT(distinct col2), col1 FROM test group by col1 + // +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] ) + // +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36] + // +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] ) + // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) + // +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ... + // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) + // +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ...... 
+ // If the aggregateExpressions is empty, we only want to build groupingExpressions, + // and skip processing of aggregateExpressions. + if (aggregateExpressions.isEmpty) { + val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() + hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + if (resultExprs.exists(_.isEmpty)) { + withInfo( + aggregate, + s"Unsupported result expressions found in: $resultExpressions", + resultExpressions: _*) + return None + } + hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) + Some(builder.setHashAgg(hashAggBuilder).build()) + } else { + val modes = aggregateExpressions.map(_.mode).distinct + + if (modes.size != 1) { + // This shouldn't happen as all aggregation expressions should share the same mode. + // Fallback to Spark nevertheless here. + withInfo(aggregate, "All aggregate expressions do not have the same mode") + return None + } + + val mode = modes.head match { + case Partial => CometAggregateMode.Partial + case Final => CometAggregateMode.Final + case _ => + withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}") + return None + } + + // In final mode, the aggregate expressions are bound to the output of the + // child and partial aggregate expressions buffer attributes produced by partial + // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, + // we don't have to do this because we don't use the merging expression. + val binding = mode != CometAggregateMode.Final + // `output` is only used when `binding` is true (i.e., non-Final) + val output = child.output + + val aggExprs = + aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf)) + if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && + aggExprs.forall(_.isDefined)) { + val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() + hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) + hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) + if (mode == CometAggregateMode.Final) { + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + if (resultExprs.exists(_.isEmpty)) { + withInfo( + aggregate, + s"Unsupported result expressions found in: $resultExpressions", + resultExpressions: _*) + return None + } + hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) + } + hashAggBuilder.setModeValue(mode.getNumber) + Some(builder.setHashAgg(hashAggBuilder).build()) + } else { + val allChildren: Seq[Expression] = + groupingExpressions ++ aggregateExpressions ++ aggregateAttributes + withInfo(aggregate, allChildren: _*) + None + } + } + + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometExpand.scala b/spark/src/main/scala/org/apache/comet/serde/CometExpand.scala new file mode 100644 index 0000000000..5979eed4dc --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometExpand.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.ExpandExec + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +object CometExpand extends CometOperatorSerde[ExpandExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_EXPAND_ENABLED) + + override def convert( + op: ExpandExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + var allProjExprs: Seq[Expression] = Seq() + val projExprs = op.projections.flatMap(_.map(e => { + allProjExprs = allProjExprs :+ e + exprToProto(e, op.child.output) + })) + + if (projExprs.forall(_.isDefined) && childOp.nonEmpty) { + val expandBuilder = OperatorOuterClass.Expand + .newBuilder() + .addAllProjectList(projExprs.map(_.get).asJava) + .setNumExprPerProject(op.projections.head.size) + Some(builder.setExpand(expandBuilder).build()) + } else { + withInfo(op, allProjExprs: _*) + None + } + + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometWindowExec.scala b/spark/src/main/scala/org/apache/comet/serde/CometWindow.scala similarity index 98% rename from spark/src/main/scala/org/apache/comet/serde/CometWindowExec.scala rename to spark/src/main/scala/org/apache/comet/serde/CometWindow.scala index dafe019421..7e963d6326 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometWindowExec.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometWindow.scala @@ -30,7 +30,7 @@ import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde.{exprToProto, windowExprToProto} -object CometWindowExec extends CometOperatorSerde[WindowExec] { +object CometWindow extends CometOperatorSerde[WindowExec] { override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_WINDOW_ENABLED) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b6d44496dd..8b1c10ee88 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -43,7 +43,7 @@ import org.apache.comet.{CometConf, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc} -import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} +import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} import org.apache.comet.serde.Types.{DataType => ProtoDataType} import 
org.apache.comet.serde.Types.DataType._ @@ -66,7 +66,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[GlobalLimitExec] -> CometGlobalLimit, classOf[HashJoin] -> CometHashJoin, classOf[SortMergeJoinExec] -> CometSortMergeJoin, - classOf[WindowExec] -> CometWindowExec, + classOf[ExpandExec] -> CometExpand, + classOf[WindowExec] -> CometWindow, classOf[SortExec] -> CometSort) private val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( @@ -935,150 +936,11 @@ object QueryPlanSerde extends Logging with CometExprShim { case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => CometNativeScan.convert(scan, builder, childOp: _*) - case ExpandExec(projections, _, child) if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) => - var allProjExprs: Seq[Expression] = Seq() - val projExprs = projections.flatMap(_.map(e => { - allProjExprs = allProjExprs :+ e - exprToProto(e, child.output) - })) - - if (projExprs.forall(_.isDefined) && childOp.nonEmpty) { - val expandBuilder = OperatorOuterClass.Expand - .newBuilder() - .addAllProjectList(projExprs.map(_.get).asJava) - .setNumExprPerProject(projections.head.size) - Some(builder.setExpand(expandBuilder).build()) - } else { - withInfo(op, allProjExprs: _*) - None - } - case aggregate: BaseAggregateExec if (aggregate.isInstanceOf[HashAggregateExec] || aggregate.isInstanceOf[ObjectHashAggregateExec]) && CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => - val groupingExpressions = aggregate.groupingExpressions - val aggregateExpressions = aggregate.aggregateExpressions - val aggregateAttributes = aggregate.aggregateAttributes - val resultExpressions = aggregate.resultExpressions - val child = aggregate.child - - if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) { - withInfo(op, "No group by or aggregation") - return None - } - - // Aggregate expressions with filter are not supported yet. - if (aggregateExpressions.exists(_.filter.isDefined)) { - withInfo(op, "Aggregate expression with filter is not supported") - return None - } - - if (groupingExpressions.exists(expr => - expr.dataType match { - case _: MapType => true - case _ => false - })) { - withInfo(op, "Grouping on map types is not supported") - return None - } - - val groupingExprsWithInput = - groupingExpressions.map(expr => expr.name -> exprToProto(expr, child.output)) - - val emptyExprs = groupingExprsWithInput.collect { - case (expr, proto) if proto.isEmpty => expr - } - - if (emptyExprs.nonEmpty) { - withInfo(op, s"Unsupported group expressions: ${emptyExprs.mkString(", ")}") - return None - } - - val groupingExprs = groupingExprsWithInput.map(_._2) - - // In some of the cases, the aggregateExpressions could be empty. - // For example, if the aggregate functions only have group by or if the aggregate - // functions only have distinct aggregate functions: - // - // SELECT COUNT(distinct col2), col1 FROM test group by col1 - // +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] ) - // +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36] - // +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] ) - // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) - // +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ... - // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) - // +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ...... 
- // If the aggregateExpressions is empty, we only want to build groupingExpressions, - // and skip processing of aggregateExpressions. - if (aggregateExpressions.isEmpty) { - val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() - hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withInfo( - op, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None - } - hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - Some(builder.setHashAgg(hashAggBuilder).build()) - } else { - val modes = aggregateExpressions.map(_.mode).distinct - - if (modes.size != 1) { - // This shouldn't happen as all aggregation expressions should share the same mode. - // Fallback to Spark nevertheless here. - withInfo(op, "All aggregate expressions do not have the same mode") - return None - } - - val mode = modes.head match { - case Partial => CometAggregateMode.Partial - case Final => CometAggregateMode.Final - case _ => - withInfo(op, s"Unsupported aggregation mode ${modes.head}") - return None - } - - // In final mode, the aggregate expressions are bound to the output of the - // child and partial aggregate expressions buffer attributes produced by partial - // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, - // we don't have to do this because we don't use the merging expression. - val binding = mode != CometAggregateMode.Final - // `output` is only used when `binding` is true (i.e., non-Final) - val output = child.output - - val aggExprs = - aggregateExpressions.map(aggExprToProto(_, output, binding, op.conf)) - if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && - aggExprs.forall(_.isDefined)) { - val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() - hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) - hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) - if (mode == CometAggregateMode.Final) { - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withInfo( - op, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None - } - hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - } - hashAggBuilder.setModeValue(mode.getNumber) - Some(builder.setHashAgg(hashAggBuilder).build()) - } else { - val allChildren: Seq[Expression] = - groupingExpressions ++ aggregateExpressions ++ aggregateAttributes - withInfo(op, allChildren: _*) - None - } - } + CometAggregate.convert(aggregate, builder, childOp: _*) case op => opSerdeMap.get(op.getClass) match { From b515d8fe3b779e43cf2bdb7be81d048a450cff0a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 07:41:53 -0700 Subject: [PATCH 07/16] docs --- docs/source/user-guide/latest/configs.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index 1cc05dfc78..bf6388549c 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -155,7 +155,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.exec.sortMergeJoinWithJoinFilter.enabled` | 
Experimental support for Sort Merge Join with filter | false | | `spark.comet.exec.takeOrderedAndProject.enabled` | Whether to enable takeOrderedAndProject by default. | true | | `spark.comet.exec.union.enabled` | Whether to enable union by default. | true | -| `spark.comet.exec.window.enabled` | Whether to enable window by default. | true | +| `spark.comet.exec.window.enabled` | Whether to enable window by default. | false | ## Enabling or Disabling Individual Scalar Expressions From fcab57ecf40ed1ec50e05fa84e00313c1ffcb221 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 07:52:17 -0700 Subject: [PATCH 08/16] refactor --- .../apache/comet/serde/CometAggregate.scala | 37 +++++++-- .../apache/comet/serde/QueryPlanSerde.scala | 75 +++++++++---------- 2 files changed, 67 insertions(+), 45 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala index 9b577d5a91..f0cf244f1e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala @@ -23,7 +23,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} -import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.types.MapType import org.apache.comet.{CometConf, ConfigEntry} @@ -31,12 +31,9 @@ import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto} -object CometAggregate extends CometOperatorSerde[BaseAggregateExec] { +trait CometBaseAggregate { - override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( - CometConf.COMET_EXEC_AGGREGATE_ENABLED) - - override def convert( + def doConvert( aggregate: BaseAggregateExec, builder: Operator.Builder, childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { @@ -166,3 +163,31 @@ object CometAggregate extends CometOperatorSerde[BaseAggregateExec] { } } + +object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with CometBaseAggregate { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_AGGREGATE_ENABLED) + + override def convert( + aggregate: HashAggregateExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + doConvert(aggregate, builder, childOp: _*) + } +} + +object CometObjectHashAggregate + extends CometOperatorSerde[ObjectHashAggregateExec] + with CometBaseAggregate { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_AGGREGATE_ENABLED) + + override def convert( + aggregate: ObjectHashAggregateExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + doConvert(aggregate, builder, childOp: _*) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8b1c10ee88..a5f3805322 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{HashJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec @@ -64,6 +64,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[FilterExec] -> CometFilter, classOf[LocalLimitExec] -> CometLocalLimit, classOf[GlobalLimitExec] -> CometGlobalLimit, + classOf[HashAggregateExec] -> CometHashAggregate, + classOf[ObjectHashAggregateExec] -> CometObjectHashAggregate, classOf[HashJoin] -> CometHashJoin, classOf[SortMergeJoinExec] -> CometSortMergeJoin, classOf[ExpandExec] -> CometExpand, @@ -926,52 +928,47 @@ object QueryPlanSerde extends Logging with CometExprShim { * converted to a native operator. */ def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = { - val conf = op.conf val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id) childOp.foreach(builder.addChildren) - op match { - - // Fully native scan for V1 - case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => - CometNativeScan.convert(scan, builder, childOp: _*) - - case aggregate: BaseAggregateExec - if (aggregate.isInstanceOf[HashAggregateExec] || - aggregate.isInstanceOf[ObjectHashAggregateExec]) && - CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => - CometAggregate.convert(aggregate, builder, childOp: _*) + def getOperatorSerde: Option[CometOperatorSerde[_]] = { + op match { + case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => + Some(CometNativeScan) + case _ => + opSerdeMap.get(op.getClass) + } + } - case op => - opSerdeMap.get(op.getClass) match { - case Some(handler) => - handler.enabledConfig.foreach { enabledConfig => - if (!enabledConfig.get(op.conf)) { - withInfo( - op, - s"Native support for operator ${op.getClass.getSimpleName} is disabled. " + - s"Set ${enabledConfig.key}=true to enable it.") - if (isCometSink(op)) { - return cometSink(op, builder) - } else { - return None - } - } - } - handler.asInstanceOf[CometOperatorSerde[SparkPlan]].convert(op, builder, childOp: _*) - case _ => + getOperatorSerde match { + case Some(handler) => + handler.enabledConfig.foreach { enabledConfig => + if (!enabledConfig.get(op.conf)) { + withInfo( + op, + s"Native support for operator ${op.getClass.getSimpleName} is disabled. " + + s"Set ${enabledConfig.key}=true to enable it.") if (isCometSink(op)) { - cometSink(op, builder) + return cometSink(op, builder) } else { - // Emit warning if: - // 1. it is not Spark shuffle operator, which is handled separately - // 2. 
it is not a Comet operator - if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) { - withInfo(op, s"unsupported Spark operator: ${op.nodeName}") - } - None + return None } + } + } + handler.asInstanceOf[CometOperatorSerde[SparkPlan]].convert(op, builder, childOp: _*) + case _ => + if (isCometSink(op)) { + cometSink(op, builder) + } else { + // Emit warning if: + // 1. it is not Spark shuffle operator, which is handled separately + // 2. it is not a Comet operator + if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) { + withInfo(op, s"unsupported Spark operator: ${op.nodeName}") + } + None } + } } From 4cde2867ff7308d2cb439d9ae442fe7308d51eaa Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 07:54:26 -0700 Subject: [PATCH 09/16] refactor --- .../serde/CometAggregateExpressionSerde.scala | 67 +++++++++ .../comet/serde/CometExpressionSerde.scala | 66 +++++++++ .../comet/serde/CometOperatorSerde.scala | 57 ++++++++ .../comet/serde/CometScalarFunction.scala | 34 +++++ .../apache/comet/serde/QueryPlanSerde.scala | 134 +----------------- 5 files changed, 226 insertions(+), 132 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala new file mode 100644 index 0000000000..c0c2b07284 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} +import org.apache.spark.sql.internal.SQLConf + +/** + * Trait for providing serialization logic for aggregate expressions. + */ +trait CometAggregateExpressionSerde[T <: AggregateFunction] { + + /** + * Get a short name for the expression that can be used as part of a config key related to the + * expression, such as enabling or disabling that expression. + * + * @param expr + * The Spark expression. 
+ * @return + * Short name for the expression, defaulting to the Spark class name + */ + def getExprConfigName(expr: T): String = expr.getClass.getSimpleName + + /** + * Convert a Spark expression into a protocol buffer representation that can be passed into + * native code. + * + * @param aggExpr + * The aggregate expression. + * @param expr + * The aggregate function. + * @param inputs + * The input attributes. + * @param binding + * Whether the attributes are bound (this is only relevant in aggregate expressions). + * @param conf + * SQLConf + * @return + * Protocol buffer representation, or None if the expression could not be converted. In this + * case it is expected that the input expression will have been tagged with reasons why it + * could not be converted. + */ + def convert( + aggExpr: AggregateExpression, + expr: T, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala new file mode 100644 index 0000000000..20c0343037 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +/** + * Trait for providing serialization logic for expressions. + */ +trait CometExpressionSerde[T <: Expression] { + + /** + * Get a short name for the expression that can be used as part of a config key related to the + * expression, such as enabling or disabling that expression. + * + * @param expr + * The Spark expression. + * @return + * Short name for the expression, defaulting to the Spark class name + */ + def getExprConfigName(expr: T): String = expr.getClass.getSimpleName + + /** + * Determine the support level of the expression based on its attributes. + * + * @param expr + * The Spark expression. + * @return + * Support level (Compatible, Incompatible, or Unsupported). + */ + def getSupportLevel(expr: T): SupportLevel = Compatible(None) + + /** + * Convert a Spark expression into a protocol buffer representation that can be passed into + * native code. + * + * @param expr + * The Spark expression. + * @param inputs + * The input attributes. + * @param binding + * Whether the attributes are bound (this is only relevant in aggregate expressions). + * @return + * Protocol buffer representation, or None if the expression could not be converted. In this + * case it is expected that the input expression will have been tagged with reasons why it + * could not be converted. 
+ */ + def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala new file mode 100644 index 0000000000..c6a95ec88a --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.comet.ConfigEntry +import org.apache.comet.serde.OperatorOuterClass.Operator + +/** + * Trait for providing serialization logic for operators. + */ +trait CometOperatorSerde[T <: SparkPlan] { + + /** + * Convert a Spark operator into a protocol buffer representation that can be passed into native + * code. + * + * @param op + * The Spark operator. + * @param builder + * The protobuf builder for the operator. + * @param childOp + * Child operators that have already been converted to Comet. + * @return + * Protocol buffer representation, or None if the operator could not be converted. In this + * case it is expected that the input operator will have been tagged with reasons why it could + * not be converted. + */ + def convert( + op: T, + builder: Operator.Builder, + childOp: Operator*): Option[OperatorOuterClass.Operator] + + /** + * Get the optional Comet configuration entry that is used to enable or disable native support + * for this operator. + */ + def enabledConfig: Option[ConfigEntry[Boolean]] +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala new file mode 100644 index 0000000000..aa3bf775fb --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} + +/** Serde for scalar function. */ +case class CometScalarFunction[T <: Expression](name: String) extends CometExpressionSerde[T] { + override def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProto(name, childExpr: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index a5f3805322..d11ba363aa 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -39,12 +39,11 @@ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc} import org.apache.comet.serde.OperatorOuterClass.Operator -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} import org.apache.comet.serde.Types.{DataType => ProtoDataType} import org.apache.comet.serde.Types.DataType._ import org.apache.comet.serde.literals.CometLiteral @@ -1079,7 +1078,6 @@ object QueryPlanSerde extends Logging with CometExprShim { } // scalastyle:off - /** * Align w/ Arrow's * [[https://github.com/apache/arrow-rs/blob/55.2.0/arrow-ord/src/rank.rs#L30-L40 can_rank]] and @@ -1087,7 +1085,7 @@ object QueryPlanSerde extends Logging with CometExprShim { * * TODO: Include SparkSQL's [[YearMonthIntervalType]] and [[DayTimeIntervalType]] */ - // scalastyle:off + // scalastyle:on def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = { def canRank(dt: DataType): Boolean = { dt match { @@ -1147,131 +1145,3 @@ case class Incompatible(notes: Option[String] = None) extends SupportLevel /** Comet does not support this feature */ case class Unsupported(notes: Option[String] = None) extends SupportLevel - -/** - * Trait for providing serialization logic for operators. - */ -trait CometOperatorSerde[T <: SparkPlan] { - - /** - * Convert a Spark operator into a protocol buffer representation that can be passed into native - * code. - * - * @param op - * The Spark operator. - * @param builder - * The protobuf builder for the operator. - * @param childOp - * Child operators that have already been converted to Comet. - * @return - * Protocol buffer representation, or None if the operator could not be converted. In this - * case it is expected that the input operator will have been tagged with reasons why it could - * not be converted. - */ - def convert( - op: T, - builder: Operator.Builder, - childOp: Operator*): Option[OperatorOuterClass.Operator] - - /** - * Get the optional Comet configuration entry that is used to enable or disable native support - * for this operator. 
- */ - def enabledConfig: Option[ConfigEntry[Boolean]] -} - -/** - * Trait for providing serialization logic for expressions. - */ -trait CometExpressionSerde[T <: Expression] { - - /** - * Get a short name for the expression that can be used as part of a config key related to the - * expression, such as enabling or disabling that expression. - * - * @param expr - * The Spark expression. - * @return - * Short name for the expression, defaulting to the Spark class name - */ - def getExprConfigName(expr: T): String = expr.getClass.getSimpleName - - /** - * Determine the support level of the expression based on its attributes. - * - * @param expr - * The Spark expression. - * @return - * Support level (Compatible, Incompatible, or Unsupported). - */ - def getSupportLevel(expr: T): SupportLevel = Compatible(None) - - /** - * Convert a Spark expression into a protocol buffer representation that can be passed into - * native code. - * - * @param expr - * The Spark expression. - * @param inputs - * The input attributes. - * @param binding - * Whether the attributes are bound (this is only relevant in aggregate expressions). - * @return - * Protocol buffer representation, or None if the expression could not be converted. In this - * case it is expected that the input expression will have been tagged with reasons why it - * could not be converted. - */ - def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] -} - -/** - * Trait for providing serialization logic for aggregate expressions. - */ -trait CometAggregateExpressionSerde[T <: AggregateFunction] { - - /** - * Get a short name for the expression that can be used as part of a config key related to the - * expression, such as enabling or disabling that expression. - * - * @param expr - * The Spark expression. - * @return - * Short name for the expression, defaulting to the Spark class name - */ - def getExprConfigName(expr: T): String = expr.getClass.getSimpleName - - /** - * Convert a Spark expression into a protocol buffer representation that can be passed into - * native code. - * - * @param aggExpr - * The aggregate expression. - * @param expr - * The aggregate function. - * @param inputs - * The input attributes. - * @param binding - * Whether the attributes are bound (this is only relevant in aggregate expressions). - * @param conf - * SQLConf - * @return - * Protocol buffer representation, or None if the expression could not be converted. In this - * case it is expected that the input expression will have been tagged with reasons why it - * could not be converted. - */ - def convert( - aggExpr: AggregateExpression, - expr: T, - inputs: Seq[Attribute], - binding: Boolean, - conf: SQLConf): Option[ExprOuterClass.AggExpr] -} - -/** Serde for scalar function. 
*/ -case class CometScalarFunction[T <: Expression](name: String) extends CometExpressionSerde[T] { - override def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { - val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProto(name, childExpr: _*) - optExprWithInfo(optExpr, expr, expr.children: _*) - } -} From 8afadf7add676e241ab2c3b6bc0fbd075d1ac79a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 08:02:52 -0700 Subject: [PATCH 10/16] fix --- .../apache/comet/serde/CometHashJoin.scala | 33 ++++++++++++++----- .../apache/comet/serde/QueryPlanSerde.scala | 7 ++-- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala index 94fd880f8f..937ea66571 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala @@ -31,19 +31,16 @@ import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.OperatorOuterClass.{BuildSide, JoinType, Operator} import org.apache.comet.serde.QueryPlanSerde.exprToProto -object CometHashJoin extends CometOperatorSerde[HashJoin] { +trait CometHashJoin { - override def enabledConfig: Option[ConfigEntry[Boolean]] = - Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED) - - override def convert( - join: HashJoin, - builder: Operator.Builder, - childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + def doConvert( + join: HashJoin, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { // `HashJoin` has only two implementations in Spark, but we check the type of the join to // make sure we are handling the correct join type. 
if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(join.conf) && - join.isInstanceOf[ShuffledHashJoinExec]) && + join.isInstanceOf[ShuffledHashJoinExec]) && !(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(join.conf) && join.isInstanceOf[BroadcastHashJoinExec])) { withInfo(join, s"Invalid hash join type ${join.nodeName}") @@ -100,3 +97,21 @@ object CometHashJoin extends CometOperatorSerde[HashJoin] { } } } + +object CometBroadcastHashJoin extends CometOperatorSerde[HashJoin] with CometHashJoin { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED) + + override def convert(join: HashJoin, builder: Operator.Builder, childOp: Operator*): Option[Operator] = + doConvert(join, builder, childOp: _*) +} + +object CometShuffleHashJoin extends CometOperatorSerde[HashJoin] with CometHashJoin { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED) + + override def convert(join: HashJoin, builder: Operator.Builder, childOp: Operator*): Option[Operator] = + doConvert(join, builder, childOp: _*) +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index d11ba363aa..e70a55e060 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -20,7 +20,6 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -34,11 +33,10 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{HashJoin, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ @@ -65,7 +63,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[GlobalLimitExec] -> CometGlobalLimit, classOf[HashAggregateExec] -> CometHashAggregate, classOf[ObjectHashAggregateExec] -> CometObjectHashAggregate, - classOf[HashJoin] -> CometHashJoin, + classOf[BroadcastHashJoinExec] -> CometBroadcastHashJoin, + classOf[ShuffledHashJoinExec] -> CometShuffleHashJoin, classOf[SortMergeJoinExec] -> CometSortMergeJoin, classOf[ExpandExec] -> CometExpand, classOf[WindowExec] -> CometWindow, From ae2254ebca2efe0d83cbadba05f05ca7bc921483 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 08:04:22 -0700 Subject: [PATCH 11/16] scalastyle --- .../org/apache/comet/serde/CometHashJoin.scala | 18 ++++++++++++------ .../apache/comet/serde/QueryPlanSerde.scala | 4 +++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala index 937ea66571..67fb67a2e7 100644 
--- a/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala @@ -34,13 +34,13 @@ import org.apache.comet.serde.QueryPlanSerde.exprToProto trait CometHashJoin { def doConvert( - join: HashJoin, - builder: Operator.Builder, - childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + join: HashJoin, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { // `HashJoin` has only two implementations in Spark, but we check the type of the join to // make sure we are handling the correct join type. if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(join.conf) && - join.isInstanceOf[ShuffledHashJoinExec]) && + join.isInstanceOf[ShuffledHashJoinExec]) && !(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(join.conf) && join.isInstanceOf[BroadcastHashJoinExec])) { withInfo(join, s"Invalid hash join type ${join.nodeName}") @@ -103,7 +103,10 @@ object CometBroadcastHashJoin extends CometOperatorSerde[HashJoin] with CometHas override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED) - override def convert(join: HashJoin, builder: Operator.Builder, childOp: Operator*): Option[Operator] = + override def convert( + join: HashJoin, + builder: Operator.Builder, + childOp: Operator*): Option[Operator] = doConvert(join, builder, childOp: _*) } @@ -112,6 +115,9 @@ object CometShuffleHashJoin extends CometOperatorSerde[HashJoin] with CometHashJ override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED) - override def convert(join: HashJoin, builder: Operator.Builder, childOp: Operator*): Option[Operator] = + override def convert( + join: HashJoin, + builder: Operator.Builder, + childOp: Operator*): Option[Operator] = doConvert(join, builder, childOp: _*) } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e70a55e060..1e63b70b81 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -20,6 +20,7 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -33,10 +34,11 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ + import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ From 3dbb10dbd0eed1d42c23d53e5a9286eb70bf1fb2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 08:49:46 -0700 Subject: [PATCH 12/16] remove map approach --- 
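Illustration only (not part of this change): operator2Proto is the conversion
entry point used throughout this series, and it expects the already converted
children to be passed in as protobuf operators. The sketch below shows how a
caller could drive it bottom-up over a plan subtree. The
NativePlanConversionSketch object and the toNativePlan helper are hypothetical
and written only to illustrate the API shape, not taken from Comet's actual
planner integration.

    import org.apache.spark.sql.execution.SparkPlan

    import org.apache.comet.serde.OperatorOuterClass.Operator
    import org.apache.comet.serde.QueryPlanSerde

    object NativePlanConversionSketch {

      // Convert a plan subtree bottom-up, giving up as soon as any child
      // cannot be expressed as a native operator.
      def toNativePlan(plan: SparkPlan): Option[Operator] = {
        val children = plan.children.map(toNativePlan)
        if (children.forall(_.isDefined)) {
          // All children converted, so try this operator with its serialized
          // children attached.
          QueryPlanSerde.operator2Proto(plan, children.flatten: _*)
        } else {
          // At least one child is unsupported, so fall back to Spark for this
          // subtree.
          None
        }
      }
    }
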
.../apache/comet/serde/QueryPlanSerde.scala | 209 ++++++++++-------- 1 file changed, 112 insertions(+), 97 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 1e63b70b81..f512aec0f9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -58,19 +58,7 @@ object QueryPlanSerde extends Logging with CometExprShim { * Mapping of Spark operator class to Comet operator handler. */ private val opSerdeMap: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] = - Map( - classOf[ProjectExec] -> CometProject, - classOf[FilterExec] -> CometFilter, - classOf[LocalLimitExec] -> CometLocalLimit, - classOf[GlobalLimitExec] -> CometGlobalLimit, - classOf[HashAggregateExec] -> CometHashAggregate, - classOf[ObjectHashAggregateExec] -> CometObjectHashAggregate, - classOf[BroadcastHashJoinExec] -> CometBroadcastHashJoin, - classOf[ShuffledHashJoinExec] -> CometShuffleHashJoin, - classOf[SortMergeJoinExec] -> CometSortMergeJoin, - classOf[ExpandExec] -> CometExpand, - classOf[WindowExec] -> CometWindow, - classOf[SortExec] -> CometSort) + Map(classOf[ProjectExec] -> CometProject, classOf[SortExec] -> CometSort) private val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[ArrayAppend] -> CometArrayAppend, @@ -928,104 +916,131 @@ object QueryPlanSerde extends Logging with CometExprShim { * converted to a native operator. */ def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = { + val conf = op.conf val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id) childOp.foreach(builder.addChildren) - def getOperatorSerde: Option[CometOperatorSerde[_]] = { - op match { - case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => - Some(CometNativeScan) - case _ => - opSerdeMap.get(op.getClass) - } - } + op match { - getOperatorSerde match { - case Some(handler) => - handler.enabledConfig.foreach { enabledConfig => - if (!enabledConfig.get(op.conf)) { - withInfo( - op, - s"Native support for operator ${op.getClass.getSimpleName} is disabled. " + - s"Set ${enabledConfig.key}=true to enable it.") - if (isCometSink(op)) { - return cometSink(op, builder) - } else { - return None - } - } - } - handler.asInstanceOf[CometOperatorSerde[SparkPlan]].convert(op, builder, childOp: _*) - case _ => - if (isCometSink(op)) { - cometSink(op, builder) - } else { - // Emit warning if: - // 1. it is not Spark shuffle operator, which is handled separately - // 2. 
it is not a Comet operator - if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) { - withInfo(op, s"unsupported Spark operator: ${op.nodeName}") - } - None - } + // Fully native scan for V1 + case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => + CometNativeScan.convert(scan, builder, childOp: _*) - } - } + case filter: FilterExec if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) => + CometFilter.convert(filter, builder, childOp: _*) - def cometSink(op: SparkPlan, builder: Operator.Builder): Option[Operator] = { - val supportedTypes = - op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + case limit: LocalLimitExec if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) => + CometLocalLimit.convert(limit, builder, childOp: _*) - if (!supportedTypes) { - withInfo(op, "Unsupported data type") - return None - } + case globalLimitExec: GlobalLimitExec + if CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED.get(conf) => + CometGlobalLimit.convert(globalLimitExec, builder, childOp: _*) - // These operators are source of Comet native execution chain - val scanBuilder = OperatorOuterClass.Scan.newBuilder() - val source = op.simpleStringWithNodeId() - if (source.isEmpty) { - scanBuilder.setSource(op.getClass.getSimpleName) - } else { - scanBuilder.setSource(source) - } + case expand: ExpandExec if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) => + CometExpand.convert(expand, builder, childOp: _*) - val ffiSafe = op match { - case _ if isExchangeSink(op) => - // Source of broadcast exchange batches is ArrowStreamReader - // Source of shuffle exchange batches is NativeBatchDecoderIterator - true - case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_COMET => - // native_comet scan reuses mutable buffers - false - case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT => - // native_iceberg_compat scan reuses mutable buffers for constant columns - // https://github.com/apache/datafusion-comet/issues/2152 - false - case _ => - false - } - scanBuilder.setArrowFfiSafe(ffiSafe) + case _: WindowExec + if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) => + withInfo(op, "Window expressions are not supported") + None - val scanTypes = op.output.flatten { attr => - serializeDataType(attr.dataType) - } + case aggregate: HashAggregateExec if CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => + CometHashAggregate.convert(aggregate, builder, childOp: _*) - if (scanTypes.length == op.output.length) { - scanBuilder.addAllFields(scanTypes.asJava) + case aggregate: ObjectHashAggregateExec + if CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => + CometObjectHashAggregate.convert(aggregate, builder, childOp: _*) - // Sink operators don't have children - builder.clearChildren() + case join: BroadcastHashJoinExec => + CometBroadcastHashJoin.convert(join, builder, childOp: _*) - Some(builder.setScan(scanBuilder).build()) - } else { - // There are unsupported scan type - withInfo( - op, - s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above") - None - } + case join: ShuffledHashJoinExec => + CometShuffleHashJoin.convert(join, builder, childOp: _*) + case join: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) => + CometSortMergeJoin.convert(join, builder, childOp: _*) + + case join: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) => + withInfo(join, "SortMergeJoin is not enabled") + None + + case op if isCometSink(op) => + val 
supportedTypes = + op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + + if (!supportedTypes) { + withInfo(op, "Unsupported data type") + return None + } + + // These operators are source of Comet native execution chain + val scanBuilder = OperatorOuterClass.Scan.newBuilder() + val source = op.simpleStringWithNodeId() + if (source.isEmpty) { + scanBuilder.setSource(op.getClass.getSimpleName) + } else { + scanBuilder.setSource(source) + } + + val ffiSafe = op match { + case _ if isExchangeSink(op) => + // Source of broadcast exchange batches is ArrowStreamReader + // Source of shuffle exchange batches is NativeBatchDecoderIterator + true + case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_COMET => + // native_comet scan reuses mutable buffers + false + case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT => + // native_iceberg_compat scan reuses mutable buffers for constant columns + // https://github.com/apache/datafusion-comet/issues/2152 + false + case _ => + false + } + scanBuilder.setArrowFfiSafe(ffiSafe) + + val scanTypes = op.output.flatten { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + scanBuilder.addAllFields(scanTypes.asJava) + + // Sink operators don't have children + builder.clearChildren() + + Some(builder.setScan(scanBuilder).build()) + } else { + // There are unsupported scan type + withInfo( + op, + s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above") + None + } + + case op => + opSerdeMap.get(op.getClass) match { + case Some(handler) => + handler.enabledConfig.foreach { enabledConfig => + if (!enabledConfig.get(op.conf)) { + withInfo( + op, + s"Native support for operator ${op.getClass.getSimpleName} is disabled. " + + s"Set ${enabledConfig.key}=true to enable it.") + return None + } + } + handler.asInstanceOf[CometOperatorSerde[SparkPlan]].convert(op, builder, childOp: _*) + case _ => + // Emit warning if: + // 1. it is not Spark shuffle operator, which is handled separately + // 2. 
it is not a Comet operator + if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) { + withInfo(op, s"unsupported Spark operator: ${op.nodeName}") + } + None + } + } } /** From 30360ab864c41e83a7dcebfe0a2438c0cf1f7867 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 09:12:51 -0700 Subject: [PATCH 13/16] scalastyle --- common/src/main/scala/org/apache/comet/CometConf.scala | 2 +- .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 7bd075eb50..60fd1940bc 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -248,7 +248,7 @@ object CometConf extends ShimCometConf { val COMET_EXEC_EXPAND_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("expand", defaultValue = true) val COMET_EXEC_WINDOW_ENABLED: ConfigEntry[Boolean] = - createExecEnabledConfig("window", defaultValue = false) + createExecEnabledConfig("window", defaultValue = true) val COMET_EXEC_TAKE_ORDERED_AND_PROJECT_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("takeOrderedAndProject", defaultValue = true) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index f512aec0f9..e75038a143 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -939,8 +939,7 @@ object QueryPlanSerde extends Logging with CometExprShim { case expand: ExpandExec if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) => CometExpand.convert(expand, builder, childOp: _*) - case _: WindowExec - if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) => + case _: WindowExec if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) => withInfo(op, "Window expressions are not supported") None From a6e9023af9cf524ca8e422c9f6dabb3b1b0559ae Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 09:13:01 -0700 Subject: [PATCH 14/16] revert config --- docs/source/user-guide/latest/configs.md | 34 ++++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index bf6388549c..9490079480 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -139,23 +139,23 @@ These settings can be used to determine which parts of the plan are accelerated | Config | Description | Default Value | |--------|-------------|---------------| -| `spark.comet.exec.aggregate.enabled` | Whether to enable aggregate by default. | true | -| `spark.comet.exec.broadcastExchange.enabled` | Whether to enable broadcastExchange by default. | true | -| `spark.comet.exec.broadcastHashJoin.enabled` | Whether to enable broadcastHashJoin by default. | true | -| `spark.comet.exec.coalesce.enabled` | Whether to enable coalesce by default. | true | -| `spark.comet.exec.collectLimit.enabled` | Whether to enable collectLimit by default. | true | -| `spark.comet.exec.expand.enabled` | Whether to enable expand by default. | true | -| `spark.comet.exec.filter.enabled` | Whether to enable filter by default. | true | -| `spark.comet.exec.globalLimit.enabled` | Whether to enable globalLimit by default. 
| true | -| `spark.comet.exec.hashJoin.enabled` | Whether to enable hashJoin by default. | true | -| `spark.comet.exec.localLimit.enabled` | Whether to enable localLimit by default. | true | -| `spark.comet.exec.project.enabled` | Whether to enable project by default. | true | -| `spark.comet.exec.sort.enabled` | Whether to enable sort by default. | true | -| `spark.comet.exec.sortMergeJoin.enabled` | Whether to enable sortMergeJoin by default. | true | -| `spark.comet.exec.sortMergeJoinWithJoinFilter.enabled` | Experimental support for Sort Merge Join with filter | false | -| `spark.comet.exec.takeOrderedAndProject.enabled` | Whether to enable takeOrderedAndProject by default. | true | -| `spark.comet.exec.union.enabled` | Whether to enable union by default. | true | -| `spark.comet.exec.window.enabled` | Whether to enable window by default. | false | +| `spark.comet.exec.aggregate.enabled` | Whether to enable aggregate by default. | true | +| `spark.comet.exec.broadcastExchange.enabled` | Whether to enable broadcastExchange by default. | true | +| `spark.comet.exec.broadcastHashJoin.enabled` | Whether to enable broadcastHashJoin by default. | true | +| `spark.comet.exec.coalesce.enabled` | Whether to enable coalesce by default. | true | +| `spark.comet.exec.collectLimit.enabled` | Whether to enable collectLimit by default. | true | +| `spark.comet.exec.expand.enabled` | Whether to enable expand by default. | true | +| `spark.comet.exec.filter.enabled` | Whether to enable filter by default. | true | +| `spark.comet.exec.globalLimit.enabled` | Whether to enable globalLimit by default. | true | +| `spark.comet.exec.hashJoin.enabled` | Whether to enable hashJoin by default. | true | +| `spark.comet.exec.localLimit.enabled` | Whether to enable localLimit by default. | true | +| `spark.comet.exec.project.enabled` | Whether to enable project by default. | true | +| `spark.comet.exec.sort.enabled` | Whether to enable sort by default. | true | +| `spark.comet.exec.sortMergeJoin.enabled` | Whether to enable sortMergeJoin by default. | true | +| `spark.comet.exec.sortMergeJoinWithJoinFilter.enabled` | Experimental support for Sort Merge Join with filter | false | +| `spark.comet.exec.takeOrderedAndProject.enabled` | Whether to enable takeOrderedAndProject by default. | true | +| `spark.comet.exec.union.enabled` | Whether to enable union by default. | true | +| `spark.comet.exec.window.enabled` | Whether to enable window by default. | true | ## Enabling or Disabling Individual Scalar Expressions From c719eeba3bfb5c9e07b3d3dcd9e838d928e68894 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 09:16:04 -0700 Subject: [PATCH 15/16] revert docs --- docs/source/user-guide/latest/configs.md | 34 ++++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index 9490079480..1cc05dfc78 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -139,23 +139,23 @@ These settings can be used to determine which parts of the plan are accelerated | Config | Description | Default Value | |--------|-------------|---------------| -| `spark.comet.exec.aggregate.enabled` | Whether to enable aggregate by default. | true | -| `spark.comet.exec.broadcastExchange.enabled` | Whether to enable broadcastExchange by default. | true | -| `spark.comet.exec.broadcastHashJoin.enabled` | Whether to enable broadcastHashJoin by default. 
| true | -| `spark.comet.exec.coalesce.enabled` | Whether to enable coalesce by default. | true | -| `spark.comet.exec.collectLimit.enabled` | Whether to enable collectLimit by default. | true | -| `spark.comet.exec.expand.enabled` | Whether to enable expand by default. | true | -| `spark.comet.exec.filter.enabled` | Whether to enable filter by default. | true | -| `spark.comet.exec.globalLimit.enabled` | Whether to enable globalLimit by default. | true | -| `spark.comet.exec.hashJoin.enabled` | Whether to enable hashJoin by default. | true | -| `spark.comet.exec.localLimit.enabled` | Whether to enable localLimit by default. | true | -| `spark.comet.exec.project.enabled` | Whether to enable project by default. | true | -| `spark.comet.exec.sort.enabled` | Whether to enable sort by default. | true | -| `spark.comet.exec.sortMergeJoin.enabled` | Whether to enable sortMergeJoin by default. | true | -| `spark.comet.exec.sortMergeJoinWithJoinFilter.enabled` | Experimental support for Sort Merge Join with filter | false | -| `spark.comet.exec.takeOrderedAndProject.enabled` | Whether to enable takeOrderedAndProject by default. | true | -| `spark.comet.exec.union.enabled` | Whether to enable union by default. | true | -| `spark.comet.exec.window.enabled` | Whether to enable window by default. | true | +| `spark.comet.exec.aggregate.enabled` | Whether to enable aggregate by default. | true | +| `spark.comet.exec.broadcastExchange.enabled` | Whether to enable broadcastExchange by default. | true | +| `spark.comet.exec.broadcastHashJoin.enabled` | Whether to enable broadcastHashJoin by default. | true | +| `spark.comet.exec.coalesce.enabled` | Whether to enable coalesce by default. | true | +| `spark.comet.exec.collectLimit.enabled` | Whether to enable collectLimit by default. | true | +| `spark.comet.exec.expand.enabled` | Whether to enable expand by default. | true | +| `spark.comet.exec.filter.enabled` | Whether to enable filter by default. | true | +| `spark.comet.exec.globalLimit.enabled` | Whether to enable globalLimit by default. | true | +| `spark.comet.exec.hashJoin.enabled` | Whether to enable hashJoin by default. | true | +| `spark.comet.exec.localLimit.enabled` | Whether to enable localLimit by default. | true | +| `spark.comet.exec.project.enabled` | Whether to enable project by default. | true | +| `spark.comet.exec.sort.enabled` | Whether to enable sort by default. | true | +| `spark.comet.exec.sortMergeJoin.enabled` | Whether to enable sortMergeJoin by default. | true | +| `spark.comet.exec.sortMergeJoinWithJoinFilter.enabled` | Experimental support for Sort Merge Join with filter | false | +| `spark.comet.exec.takeOrderedAndProject.enabled` | Whether to enable takeOrderedAndProject by default. | true | +| `spark.comet.exec.union.enabled` | Whether to enable union by default. | true | +| `spark.comet.exec.window.enabled` | Whether to enable window by default. 
| true | ## Enabling or Disabling Individual Scalar Expressions From 812dd7fc1cf8efab2c5a347687895aaa8ea7588f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 8 Nov 2025 09:21:38 -0700 Subject: [PATCH 16/16] revert a change --- .../apache/comet/serde/QueryPlanSerde.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e75038a143..3e0e837c9c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -950,18 +950,20 @@ object QueryPlanSerde extends Logging with CometExprShim { if CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => CometObjectHashAggregate.convert(aggregate, builder, childOp: _*) - case join: BroadcastHashJoinExec => + case join: BroadcastHashJoinExec + if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) => CometBroadcastHashJoin.convert(join, builder, childOp: _*) - case join: ShuffledHashJoinExec => + case join: ShuffledHashJoinExec if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) => CometShuffleHashJoin.convert(join, builder, childOp: _*) - case join: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) => - CometSortMergeJoin.convert(join, builder, childOp: _*) - - case join: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) => - withInfo(join, "SortMergeJoin is not enabled") - None + case join: SortMergeJoinExec => + if (CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf)) { + CometSortMergeJoin.convert(join, builder, childOp: _*) + } else { + withInfo(join, "SortMergeJoin is not enabled") + None + } case op if isCometSink(op) => val supportedTypes =