diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala new file mode 100644 index 0000000000..f0cf244f1e --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregate.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} +import org.apache.spark.sql.types.MapType + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} +import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto} + +trait CometBaseAggregate { + + def doConvert( + aggregate: BaseAggregateExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + val groupingExpressions = aggregate.groupingExpressions + val aggregateExpressions = aggregate.aggregateExpressions + val aggregateAttributes = aggregate.aggregateAttributes + val resultExpressions = aggregate.resultExpressions + val child = aggregate.child + + if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) { + withInfo(aggregate, "No group by or aggregation") + return None + } + + // Aggregate expressions with filter are not supported yet. + if (aggregateExpressions.exists(_.filter.isDefined)) { + withInfo(aggregate, "Aggregate expression with filter is not supported") + return None + } + + if (groupingExpressions.exists(expr => + expr.dataType match { + case _: MapType => true + case _ => false + })) { + withInfo(aggregate, "Grouping on map types is not supported") + return None + } + + val groupingExprsWithInput = + groupingExpressions.map(expr => expr.name -> exprToProto(expr, child.output)) + + val emptyExprs = groupingExprsWithInput.collect { + case (expr, proto) if proto.isEmpty => expr + } + + if (emptyExprs.nonEmpty) { + withInfo(aggregate, s"Unsupported group expressions: ${emptyExprs.mkString(", ")}") + return None + } + + val groupingExprs = groupingExprsWithInput.map(_._2) + + // In some of the cases, the aggregateExpressions could be empty. 
+ // For example, if the aggregate functions only have group by or if the aggregate + // functions only have distinct aggregate functions: + // + // SELECT COUNT(distinct col2), col1 FROM test group by col1 + // +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] ) + // +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36] + // +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] ) + // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) + // +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ... + // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) + // +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ...... + // If the aggregateExpressions is empty, we only want to build groupingExpressions, + // and skip processing of aggregateExpressions. + if (aggregateExpressions.isEmpty) { + val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() + hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + if (resultExprs.exists(_.isEmpty)) { + withInfo( + aggregate, + s"Unsupported result expressions found in: $resultExpressions", + resultExpressions: _*) + return None + } + hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) + Some(builder.setHashAgg(hashAggBuilder).build()) + } else { + val modes = aggregateExpressions.map(_.mode).distinct + + if (modes.size != 1) { + // This shouldn't happen as all aggregation expressions should share the same mode. + // Fallback to Spark nevertheless here. + withInfo(aggregate, "All aggregate expressions do not have the same mode") + return None + } + + val mode = modes.head match { + case Partial => CometAggregateMode.Partial + case Final => CometAggregateMode.Final + case _ => + withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}") + return None + } + + // In final mode, the aggregate expressions are bound to the output of the + // child and partial aggregate expressions buffer attributes produced by partial + // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, + // we don't have to do this because we don't use the merging expression. 
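+ // Roughly: for a partial-mode `sum(col1)`, the child expression is bound against the
+ // scan output (e.g. `col1#1`), so `binding` is true below; for a final-mode `sum(col1)`,
+ // the expression refers to the aggregation buffer attribute produced by the partial
+ // aggregate (e.g. `sum#10`), and no binding is performed.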
+ val binding = mode != CometAggregateMode.Final + // `output` is only used when `binding` is true (i.e., non-Final) + val output = child.output + + val aggExprs = + aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf)) + if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && + aggExprs.forall(_.isDefined)) { + val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() + hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) + hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) + if (mode == CometAggregateMode.Final) { + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + if (resultExprs.exists(_.isEmpty)) { + withInfo( + aggregate, + s"Unsupported result expressions found in: $resultExpressions", + resultExpressions: _*) + return None + } + hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) + } + hashAggBuilder.setModeValue(mode.getNumber) + Some(builder.setHashAgg(hashAggBuilder).build()) + } else { + val allChildren: Seq[Expression] = + groupingExpressions ++ aggregateExpressions ++ aggregateAttributes + withInfo(aggregate, allChildren: _*) + None + } + } + + } + +} + +object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with CometBaseAggregate { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_AGGREGATE_ENABLED) + + override def convert( + aggregate: HashAggregateExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + doConvert(aggregate, builder, childOp: _*) + } +} + +object CometObjectHashAggregate + extends CometOperatorSerde[ObjectHashAggregateExec] + with CometBaseAggregate { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_AGGREGATE_ENABLED) + + override def convert( + aggregate: ObjectHashAggregateExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + doConvert(aggregate, builder, childOp: _*) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala new file mode 100644 index 0000000000..c0c2b07284 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} +import org.apache.spark.sql.internal.SQLConf + +/** + * Trait for providing serialization logic for aggregate expressions. + */ +trait CometAggregateExpressionSerde[T <: AggregateFunction] { + + /** + * Get a short name for the expression that can be used as part of a config key related to the + * expression, such as enabling or disabling that expression. + * + * @param expr + * The Spark expression. + * @return + * Short name for the expression, defaulting to the Spark class name + */ + def getExprConfigName(expr: T): String = expr.getClass.getSimpleName + + /** + * Convert a Spark expression into a protocol buffer representation that can be passed into + * native code. + * + * @param aggExpr + * The aggregate expression. + * @param expr + * The aggregate function. + * @param inputs + * The input attributes. + * @param binding + * Whether the attributes are bound (this is only relevant in aggregate expressions). + * @param conf + * SQLConf + * @return + * Protocol buffer representation, or None if the expression could not be converted. In this + * case it is expected that the input expression will have been tagged with reasons why it + * could not be converted. + */ + def convert( + aggExpr: AggregateExpression, + expr: T, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometExpand.scala b/spark/src/main/scala/org/apache/comet/serde/CometExpand.scala new file mode 100644 index 0000000000..5979eed4dc --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometExpand.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.ExpandExec + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +object CometExpand extends CometOperatorSerde[ExpandExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_EXPAND_ENABLED) + + override def convert( + op: ExpandExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + var allProjExprs: Seq[Expression] = Seq() + val projExprs = op.projections.flatMap(_.map(e => { + allProjExprs = allProjExprs :+ e + exprToProto(e, op.child.output) + })) + + if (projExprs.forall(_.isDefined) && childOp.nonEmpty) { + val expandBuilder = OperatorOuterClass.Expand + .newBuilder() + .addAllProjectList(projExprs.map(_.get).asJava) + .setNumExprPerProject(op.projections.head.size) + Some(builder.setExpand(expandBuilder).build()) + } else { + withInfo(op, allProjExprs: _*) + None + } + + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala new file mode 100644 index 0000000000..20c0343037 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +/** + * Trait for providing serialization logic for expressions. + */ +trait CometExpressionSerde[T <: Expression] { + + /** + * Get a short name for the expression that can be used as part of a config key related to the + * expression, such as enabling or disabling that expression. + * + * @param expr + * The Spark expression. + * @return + * Short name for the expression, defaulting to the Spark class name + */ + def getExprConfigName(expr: T): String = expr.getClass.getSimpleName + + /** + * Determine the support level of the expression based on its attributes. + * + * @param expr + * The Spark expression. + * @return + * Support level (Compatible, Incompatible, or Unsupported). + */ + def getSupportLevel(expr: T): SupportLevel = Compatible(None) + + /** + * Convert a Spark expression into a protocol buffer representation that can be passed into + * native code. + * + * @param expr + * The Spark expression. + * @param inputs + * The input attributes. 
+ * @param binding + * Whether the attributes are bound (this is only relevant in aggregate expressions). + * @return + * Protocol buffer representation, or None if the expression could not be converted. In this + * case it is expected that the input expression will have been tagged with reasons why it + * could not be converted. + */ + def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometFilter.scala b/spark/src/main/scala/org/apache/comet/serde/CometFilter.scala new file mode 100644 index 0000000000..1638750b5f --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometFilter.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.execution.FilterExec + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +object CometFilter extends CometOperatorSerde[FilterExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_FILTER_ENABLED) + + override def convert( + op: FilterExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + val cond = exprToProto(op.condition, op.child.output) + + if (cond.isDefined && childOp.nonEmpty) { + val filterBuilder = OperatorOuterClass.Filter + .newBuilder() + .setPredicate(cond.get) + Some(builder.setFilter(filterBuilder).build()) + } else { + withInfo(op, op.condition, op.child) + None + } + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala b/spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala new file mode 100644 index 0000000000..774e1ad77e --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometGlobalLimit.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.execution.GlobalLimitExec + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.Operator + +object CometGlobalLimit extends CometOperatorSerde[GlobalLimitExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED) + + override def convert( + op: GlobalLimitExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (childOp.nonEmpty) { + val limitBuilder = OperatorOuterClass.Limit.newBuilder() + + limitBuilder.setLimit(op.limit).setOffset(op.offset) + + Some(builder.setLimit(limitBuilder).build()) + } else { + withInfo(op, "No child operator") + None + } + + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala new file mode 100644 index 0000000000..67fb67a2e7 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometHashJoin.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec} + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.{BuildSide, JoinType, Operator} +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +trait CometHashJoin { + + def doConvert( + join: HashJoin, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + // `HashJoin` has only two implementations in Spark, but we check the type of the join to + // make sure we are handling the correct join type. 
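+ // A `ShuffledHashJoinExec` is converted only when `COMET_EXEC_HASH_JOIN_ENABLED` is set,
+ // and a `BroadcastHashJoinExec` only when `COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED` is set;
+ // any other combination falls back to Spark with the message below.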
+ if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(join.conf) && + join.isInstanceOf[ShuffledHashJoinExec]) && + !(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(join.conf) && + join.isInstanceOf[BroadcastHashJoinExec])) { + withInfo(join, s"Invalid hash join type ${join.nodeName}") + return None + } + + if (join.buildSide == BuildRight && join.joinType == LeftAnti) { + // https://github.com/apache/datafusion-comet/issues/457 + withInfo(join, "BuildRight with LeftAnti is not supported") + return None + } + + val condition = join.condition.map { cond => + val condProto = exprToProto(cond, join.left.output ++ join.right.output) + if (condProto.isEmpty) { + withInfo(join, cond) + return None + } + condProto.get + } + + val joinType = join.joinType match { + case Inner => JoinType.Inner + case LeftOuter => JoinType.LeftOuter + case RightOuter => JoinType.RightOuter + case FullOuter => JoinType.FullOuter + case LeftSemi => JoinType.LeftSemi + case LeftAnti => JoinType.LeftAnti + case _ => + // Spark doesn't support other join types + withInfo(join, s"Unsupported join type ${join.joinType}") + return None + } + + val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) + val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) + + if (leftKeys.forall(_.isDefined) && + rightKeys.forall(_.isDefined) && + childOp.nonEmpty) { + val joinBuilder = OperatorOuterClass.HashJoin + .newBuilder() + .setJoinType(joinType) + .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) + .addAllRightJoinKeys(rightKeys.map(_.get).asJava) + .setBuildSide( + if (join.buildSide == BuildLeft) BuildSide.BuildLeft else BuildSide.BuildRight) + condition.foreach(joinBuilder.setCondition) + Some(builder.setHashJoin(joinBuilder).build()) + } else { + val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys + withInfo(join, allExprs: _*) + None + } + } +} + +object CometBroadcastHashJoin extends CometOperatorSerde[HashJoin] with CometHashJoin { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED) + + override def convert( + join: HashJoin, + builder: Operator.Builder, + childOp: Operator*): Option[Operator] = + doConvert(join, builder, childOp: _*) +} + +object CometShuffleHashJoin extends CometOperatorSerde[HashJoin] with CometHashJoin { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_HASH_JOIN_ENABLED) + + override def convert( + join: HashJoin, + builder: Operator.Builder, + childOp: Operator*): Option[Operator] = + doConvert(join, builder, childOp: _*) +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala b/spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala new file mode 100644 index 0000000000..1347b12907 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometLocalLimit.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.execution.LocalLimitExec + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.Operator + +object CometLocalLimit extends CometOperatorSerde[LocalLimitExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = + Some(CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED) + + override def convert( + op: LocalLimitExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (childOp.nonEmpty) { + // LocalLimit doesn't use offset, but it shares same operator serde class. + // Just set it to zero. + val limitBuilder = OperatorOuterClass.Limit + .newBuilder() + .setLimit(op.limit) + .setOffset(0) + Some(builder.setLimit(limitBuilder).build()) + } else { + withInfo(op, "No child operator") + None + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala new file mode 100644 index 0000000000..476313a9d1 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometNativeScan.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues +import org.apache.spark.sql.comet.CometScanExec +import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{StructField, StructType} + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.objectstore.NativeConfig +import org.apache.comet.parquet.CometParquetUtils +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} + +object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = None + + override def convert( + scan: CometScanExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder() + nativeScanBuilder.setSource(scan.simpleStringWithNodeId()) + + val scanTypes = scan.output.flatten { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == scan.output.length) { + nativeScanBuilder.addAllFields(scanTypes.asJava) + + // Sink operators don't have children + builder.clearChildren() + + if (scan.conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED) && + CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.get(scan.conf)) { + + val dataFilters = new ListBuffer[Expr]() + for (filter <- scan.dataFilters) { + exprToProto(filter, scan.output) match { + case Some(proto) => dataFilters += proto + case _ => + logWarning(s"Unsupported data filter $filter") + } + } + nativeScanBuilder.addAllDataFilters(dataFilters.asJava) + } + + val possibleDefaultValues = getExistenceDefaultValues(scan.requiredSchema) + if (possibleDefaultValues.exists(_ != null)) { + // Our schema has default values. Serialize two lists, one with the default values + // and another with the indexes in the schema so the native side can map missing + // columns to these default values. + val (defaultValues, indexes) = possibleDefaultValues.zipWithIndex + .filter { case (expr, _) => expr != null } + .map { case (expr, index) => + // ResolveDefaultColumnsUtil.getExistenceDefaultValues has evaluated these + // expressions and they should now just be literals. + (Literal(expr), index.toLong.asInstanceOf[java.lang.Long]) + } + .unzip + nativeScanBuilder.addAllDefaultValues( + defaultValues.flatMap(exprToProto(_, scan.output)).toIterable.asJava) + nativeScanBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava) + } + + // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD. 
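+ // The match below handles both RDD shapes a scan may produce (`DataSourceRDD` and
+ // `FileScanRDD`). In both cases the first `PartitionedFile` is remembered so that object
+ // store options can later be derived from its path URI.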
+ var firstPartition: Option[PartitionedFile] = None + scan.inputRDD match { + case rdd: DataSourceRDD => + val partitions = rdd.partitions + partitions.foreach(p => { + val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions + inputPartitions.foreach(partition => { + if (firstPartition.isEmpty) { + firstPartition = partition.asInstanceOf[FilePartition].files.headOption + } + partition2Proto( + partition.asInstanceOf[FilePartition], + nativeScanBuilder, + scan.relation.partitionSchema) + }) + }) + case rdd: FileScanRDD => + rdd.filePartitions.foreach(partition => { + if (firstPartition.isEmpty) { + firstPartition = partition.files.headOption + } + partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema) + }) + case _ => + } + + val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields) + val requiredSchema = schema2Proto(scan.requiredSchema.fields) + val dataSchema = schema2Proto(scan.relation.dataSchema.fields) + + val dataSchemaIndexes = scan.requiredSchema.fields.map(field => { + scan.relation.dataSchema.fieldIndex(field.name) + }) + val partitionSchemaIndexes = Array + .range( + scan.relation.dataSchema.fields.length, + scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length) + + val projectionVector = (dataSchemaIndexes ++ partitionSchemaIndexes).map(idx => + idx.toLong.asInstanceOf[java.lang.Long]) + + nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava) + + // In `CometScanRule`, we ensure partitionSchema is supported. + assert(partitionSchema.length == scan.relation.partitionSchema.fields.length) + + nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava) + nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava) + nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava) + nativeScanBuilder.setSessionTimezone(scan.conf.getConfString("spark.sql.session.timeZone")) + nativeScanBuilder.setCaseSensitive(scan.conf.getConf[Boolean](SQLConf.CASE_SENSITIVE)) + + // Collect S3/cloud storage configurations + val hadoopConf = scan.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scan.relation.options) + + nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf)) + + firstPartition.foreach { partitionFile => + val objectStoreOptions = + NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri) + objectStoreOptions.foreach { case (key, value) => + nativeScanBuilder.putObjectStoreOptions(key, value) + } + } + + Some(builder.setNativeScan(nativeScanBuilder).build()) + + } else { + // There are unsupported scan type + withInfo( + scan, + s"unsupported Comet operator: ${scan.nodeName}, due to unsupported data types above") + None + } + + } + + private def schema2Proto( + fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = { + val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder() + fields.map(field => { + fieldBuilder.setName(field.name) + fieldBuilder.setDataType(serializeDataType(field.dataType).get) + fieldBuilder.setNullable(field.nullable) + fieldBuilder.build() + }) + } + + private def partition2Proto( + partition: FilePartition, + nativeScanBuilder: OperatorOuterClass.NativeScan.Builder, + partitionSchema: StructType): Unit = { + val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder() + partition.files.foreach(file => { + // Process the partition values + val partitionValues = file.partitionValues + assert(partitionValues.numFields == 
partitionSchema.length) + val partitionVals = + partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value, i) => + val attr = partitionSchema(i) + val valueProto = exprToProto(Literal(value, attr.dataType), Seq.empty) + // In `CometScanRule`, we have already checked that all partition values are + // supported. So, we can safely use `get` here. + assert( + valueProto.isDefined, + s"Unsupported partition value: $value, type: ${attr.dataType}") + valueProto.get + } + + val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder() + partitionVals.foreach(fileBuilder.addPartitionValues) + fileBuilder + .setFilePath(file.filePath.toString) + .setStart(file.start) + .setLength(file.length) + .setFileSize(file.fileSize) + partitionBuilder.addPartitionedFile(fileBuilder.build()) + }) + nativeScanBuilder.addFilePartitions(partitionBuilder.build()) + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala new file mode 100644 index 0000000000..c6a95ec88a --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.comet.ConfigEntry +import org.apache.comet.serde.OperatorOuterClass.Operator + +/** + * Trait for providing serialization logic for operators. + */ +trait CometOperatorSerde[T <: SparkPlan] { + + /** + * Convert a Spark operator into a protocol buffer representation that can be passed into native + * code. + * + * @param op + * The Spark operator. + * @param builder + * The protobuf builder for the operator. + * @param childOp + * Child operators that have already been converted to Comet. + * @return + * Protocol buffer representation, or None if the operator could not be converted. In this + * case it is expected that the input operator will have been tagged with reasons why it could + * not be converted. + */ + def convert( + op: T, + builder: Operator.Builder, + childOp: Operator*): Option[OperatorOuterClass.Operator] + + /** + * Get the optional Comet configuration entry that is used to enable or disable native support + * for this operator. 
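+ * For example, `CometFilter` returns `CometConf.COMET_EXEC_FILTER_ENABLED`, while
+ * `CometNativeScan` returns `None` because it is not gated by a dedicated config entry.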
+ */ + def enabledConfig: Option[ConfigEntry[Boolean]] +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala new file mode 100644 index 0000000000..aa3bf775fb --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} + +/** Serde for scalar function. */ +case class CometScalarFunction[T <: Expression](name: String) extends CometExpressionSerde[T] { + override def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProto(name, childExpr: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala b/spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala new file mode 100644 index 0000000000..5f926f06e8 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometSortMergeJoin.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, ExpressionSet, SortOrder} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType} + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.{JoinType, Operator} +import org.apache.comet.serde.QueryPlanSerde.exprToProto + +object CometSortMergeJoin extends CometOperatorSerde[SortMergeJoinExec] { + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED) + + override def convert( + join: SortMergeJoinExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + // `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec. + def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + keys.map(SortOrder(_, Ascending)) + } + + def getKeyOrdering( + keys: Seq[Expression], + childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = { + val requiredOrdering = requiredOrders(keys) + if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) { + keys.zip(childOutputOrdering).map { case (key, childOrder) => + val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key + SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq) + } + } else { + requiredOrdering + } + } + + if (join.condition.isDefined && + !CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED + .get(join.conf)) { + withInfo( + join, + s"${CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key} is not enabled", + join.condition.get) + return None + } + + val condition = join.condition.map { cond => + val condProto = exprToProto(cond, join.left.output ++ join.right.output) + if (condProto.isEmpty) { + withInfo(join, cond) + return None + } + condProto.get + } + + val joinType = join.joinType match { + case Inner => JoinType.Inner + case LeftOuter => JoinType.LeftOuter + case RightOuter => JoinType.RightOuter + case FullOuter => JoinType.FullOuter + case LeftSemi => JoinType.LeftSemi + case LeftAnti => JoinType.LeftAnti + case _ => + // Spark doesn't support other join types + withInfo(join, s"Unsupported join type ${join.joinType}") + return None + } + + // Checks if the join keys are supported by DataFusion SortMergeJoin. 
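+ // For example, keys of `BinaryType` or of nested types (arrays, structs, maps) are not in
+ // `supportedSortMergeJoinEqualType` below, so such joins fall back to Spark with an
+ // explanatory message.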
+ val errorMsgs = join.leftKeys.flatMap { key => + if (!supportedSortMergeJoinEqualType(key.dataType)) { + Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}") + } else { + None + } + } + + if (errorMsgs.nonEmpty) { + withInfo(join, errorMsgs.flatten.mkString("\n")) + return None + } + + val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) + val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) + + val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering) + .map(exprToProto(_, join.left.output)) + + if (sortOptions.forall(_.isDefined) && + leftKeys.forall(_.isDefined) && + rightKeys.forall(_.isDefined) && + childOp.nonEmpty) { + val joinBuilder = OperatorOuterClass.SortMergeJoin + .newBuilder() + .setJoinType(joinType) + .addAllSortOptions(sortOptions.map(_.get).asJava) + .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) + .addAllRightJoinKeys(rightKeys.map(_.get).asJava) + condition.map(joinBuilder.setCondition) + Some(builder.setSortMergeJoin(joinBuilder).build()) + } else { + val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys + withInfo(join, allExprs: _*) + None + } + + } + + /** + * Returns true if given datatype is supported as a key in DataFusion sort merge join. + */ + private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | + _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType => + true + case TimestampNTZType => true + case _ => false + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometWindow.scala b/spark/src/main/scala/org/apache/comet/serde/CometWindow.scala new file mode 100644 index 0000000000..7e963d6326 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometWindow.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, SortOrder, WindowExpression} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.window.WindowExec + +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.{exprToProto, windowExprToProto} + +object CometWindow extends CometOperatorSerde[WindowExec] { + + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_WINDOW_ENABLED) + + override def convert( + op: WindowExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + val output = op.child.output + + val winExprs: Array[WindowExpression] = op.windowExpression.flatMap { expr => + expr match { + case alias: Alias => + alias.child match { + case winExpr: WindowExpression => + Some(winExpr) + case _ => + None + } + case _ => + None + } + }.toArray + + if (winExprs.length != op.windowExpression.length) { + withInfo(op, "Unsupported window expression(s)") + return None + } + + if (op.partitionSpec.nonEmpty && op.orderSpec.nonEmpty && + !validatePartitionAndSortSpecsForWindowFunc(op.partitionSpec, op.orderSpec, op)) { + return None + } + + val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf)) + val partitionExprs = op.partitionSpec.map(exprToProto(_, op.child.output)) + + val sortOrders = op.orderSpec.map(exprToProto(_, op.child.output)) + + if (windowExprProto.forall(_.isDefined) && partitionExprs.forall(_.isDefined) + && sortOrders.forall(_.isDefined)) { + val windowBuilder = OperatorOuterClass.Window.newBuilder() + windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava) + windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava) + windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava) + Some(builder.setWindow(windowBuilder).build()) + } else { + None + } + + } + + private def validatePartitionAndSortSpecsForWindowFunc( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + op: SparkPlan): Boolean = { + if (partitionSpec.length != orderSpec.length) { + return false + } + + val partitionColumnNames = partitionSpec.collect { + case a: AttributeReference => a.name + case other => + withInfo(op, s"Unsupported partition expression: ${other.getClass.getSimpleName}") + return false + } + + val orderColumnNames = orderSpec.collect { case s: SortOrder => + s.child match { + case a: AttributeReference => a.name + case other => + withInfo(op, s"Unsupported sort expression: ${other.getClass.getSimpleName}") + return false + } + } + + if (partitionColumnNames.zip(orderColumnNames).exists { case (partCol, orderCol) => + partCol != orderCol + }) { + withInfo(op, "Partitioning and sorting specifications must be the same.") + return false + } + + true + } + +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 3f8de7693c..3e0e837c9c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -19,39 +19,31 @@ package org.apache.comet.serde -import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters._ import 
org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero} -import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} -import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ -import org.apache.comet.objectstore.NativeConfig -import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc} -import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator} -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} +import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.Types.{DataType => ProtoDataType} import org.apache.comet.serde.Types.DataType._ import org.apache.comet.serde.literals.CometLiteral @@ -911,17 +903,6 @@ object QueryPlanSerde extends Logging with CometExprShim { Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) } - /** - * Returns true if given datatype is supported as a key in DataFusion sort merge join. - */ - private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match { - case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | - _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType => - true - case TimestampNTZType => true - case _ => false - } - /** * Convert a Spark plan operator to a protobuf Comet operator. 
* @@ -943,514 +924,47 @@ object QueryPlanSerde extends Logging with CometExprShim { // Fully native scan for V1 case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => - val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder() - nativeScanBuilder.setSource(op.simpleStringWithNodeId()) - - val scanTypes = op.output.flatten { attr => - serializeDataType(attr.dataType) - } - - if (scanTypes.length == op.output.length) { - nativeScanBuilder.addAllFields(scanTypes.asJava) - - // Sink operators don't have children - builder.clearChildren() - - if (conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED) && - CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.get(conf)) { - - val dataFilters = new ListBuffer[Expr]() - for (filter <- scan.dataFilters) { - exprToProto(filter, scan.output) match { - case Some(proto) => dataFilters += proto - case _ => - logWarning(s"Unsupported data filter $filter") - } - } - nativeScanBuilder.addAllDataFilters(dataFilters.asJava) - } - - val possibleDefaultValues = getExistenceDefaultValues(scan.requiredSchema) - if (possibleDefaultValues.exists(_ != null)) { - // Our schema has default values. Serialize two lists, one with the default values - // and another with the indexes in the schema so the native side can map missing - // columns to these default values. - val (defaultValues, indexes) = possibleDefaultValues.zipWithIndex - .filter { case (expr, _) => expr != null } - .map { case (expr, index) => - // ResolveDefaultColumnsUtil.getExistenceDefaultValues has evaluated these - // expressions and they should now just be literals. - (Literal(expr), index.toLong.asInstanceOf[java.lang.Long]) - } - .unzip - nativeScanBuilder.addAllDefaultValues( - defaultValues.flatMap(exprToProto(_, scan.output)).toIterable.asJava) - nativeScanBuilder.addAllDefaultValuesIndexes(indexes.toIterable.asJava) - } - - // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD. - var firstPartition: Option[PartitionedFile] = None - scan.inputRDD match { - case rdd: DataSourceRDD => - val partitions = rdd.partitions - partitions.foreach(p => { - val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions - inputPartitions.foreach(partition => { - if (firstPartition.isEmpty) { - firstPartition = partition.asInstanceOf[FilePartition].files.headOption - } - partition2Proto( - partition.asInstanceOf[FilePartition], - nativeScanBuilder, - scan.relation.partitionSchema) - }) - }) - case rdd: FileScanRDD => - rdd.filePartitions.foreach(partition => { - if (firstPartition.isEmpty) { - firstPartition = partition.files.headOption - } - partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema) - }) - case _ => - } - - val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields) - val requiredSchema = schema2Proto(scan.requiredSchema.fields) - val dataSchema = schema2Proto(scan.relation.dataSchema.fields) - - val dataSchemaIndexes = scan.requiredSchema.fields.map(field => { - scan.relation.dataSchema.fieldIndex(field.name) - }) - val partitionSchemaIndexes = Array - .range( - scan.relation.dataSchema.fields.length, - scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length) - - val projectionVector = (dataSchemaIndexes ++ partitionSchemaIndexes).map(idx => - idx.toLong.asInstanceOf[java.lang.Long]) - - nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava) - - // In `CometScanRule`, we ensure partitionSchema is supported. 
- assert(partitionSchema.length == scan.relation.partitionSchema.fields.length) + CometNativeScan.convert(scan, builder, childOp: _*) - nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava) - nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava) - nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava) - nativeScanBuilder.setSessionTimezone(conf.getConfString("spark.sql.session.timeZone")) - nativeScanBuilder.setCaseSensitive(conf.getConf[Boolean](SQLConf.CASE_SENSITIVE)) + case filter: FilterExec if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) => + CometFilter.convert(filter, builder, childOp: _*) - // Collect S3/cloud storage configurations - val hadoopConf = scan.relation.sparkSession.sessionState - .newHadoopConfWithOptions(scan.relation.options) - - nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf)) - - firstPartition.foreach { partitionFile => - val objectStoreOptions = - NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri) - objectStoreOptions.foreach { case (key, value) => - nativeScanBuilder.putObjectStoreOptions(key, value) - } - } - - Some(builder.setNativeScan(nativeScanBuilder).build()) - - } else { - // There are unsupported scan type - withInfo( - op, - s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above") - None - } - - case FilterExec(condition, child) if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) => - val cond = exprToProto(condition, child.output) - - if (cond.isDefined && childOp.nonEmpty) { - val filterBuilder = OperatorOuterClass.Filter - .newBuilder() - .setPredicate(cond.get) - Some(builder.setFilter(filterBuilder).build()) - } else { - withInfo(op, condition, child) - None - } - - case LocalLimitExec(limit, _) if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) => - if (childOp.nonEmpty) { - // LocalLimit doesn't use offset, but it shares same operator serde class. - // Just set it to zero. 
- val limitBuilder = OperatorOuterClass.Limit - .newBuilder() - .setLimit(limit) - .setOffset(0) - Some(builder.setLimit(limitBuilder).build()) - } else { - withInfo(op, "No child operator") - None - } + case limit: LocalLimitExec if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) => + CometLocalLimit.convert(limit, builder, childOp: _*) case globalLimitExec: GlobalLimitExec if CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED.get(conf) => - if (childOp.nonEmpty) { - val limitBuilder = OperatorOuterClass.Limit.newBuilder() + CometGlobalLimit.convert(globalLimitExec, builder, childOp: _*) - limitBuilder.setLimit(globalLimitExec.limit).setOffset(globalLimitExec.offset) + case expand: ExpandExec if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) => + CometExpand.convert(expand, builder, childOp: _*) - Some(builder.setLimit(limitBuilder).build()) - } else { - withInfo(op, "No child operator") - None - } - - case ExpandExec(projections, _, child) if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) => - var allProjExprs: Seq[Expression] = Seq() - val projExprs = projections.flatMap(_.map(e => { - allProjExprs = allProjExprs :+ e - exprToProto(e, child.output) - })) - - if (projExprs.forall(_.isDefined) && childOp.nonEmpty) { - val expandBuilder = OperatorOuterClass.Expand - .newBuilder() - .addAllProjectList(projExprs.map(_.get).asJava) - .setNumExprPerProject(projections.head.size) - Some(builder.setExpand(expandBuilder).build()) - } else { - withInfo(op, allProjExprs: _*) - None - } - - case WindowExec(windowExpression, partitionSpec, orderSpec, child) - if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) => + case _: WindowExec if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) => withInfo(op, "Window expressions are not supported") None - /* - val output = child.output - - val winExprs: Array[WindowExpression] = windowExpression.flatMap { expr => - expr match { - case alias: Alias => - alias.child match { - case winExpr: WindowExpression => - Some(winExpr) - case _ => - None - } - case _ => - None - } - }.toArray - if (winExprs.length != windowExpression.length) { - withInfo(op, "Unsupported window expression(s)") - return None - } + case aggregate: HashAggregateExec if CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => + CometHashAggregate.convert(aggregate, builder, childOp: _*) - if (partitionSpec.nonEmpty && orderSpec.nonEmpty && - !validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) { - return None - } + case aggregate: ObjectHashAggregateExec + if CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => + CometObjectHashAggregate.convert(aggregate, builder, childOp: _*) - val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf)) - val partitionExprs = partitionSpec.map(exprToProto(_, child.output)) + case join: BroadcastHashJoinExec + if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) => + CometBroadcastHashJoin.convert(join, builder, childOp: _*) - val sortOrders = orderSpec.map(exprToProto(_, child.output)) + case join: ShuffledHashJoinExec if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) => + CometShuffleHashJoin.convert(join, builder, childOp: _*) - if (windowExprProto.forall(_.isDefined) && partitionExprs.forall(_.isDefined) - && sortOrders.forall(_.isDefined)) { - val windowBuilder = OperatorOuterClass.Window.newBuilder() - windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava) - windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava) - windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava) - 
Some(builder.setWindow(windowBuilder).build()) + case join: SortMergeJoinExec => + if (CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf)) { + CometSortMergeJoin.convert(join, builder, childOp: _*) } else { + withInfo(join, "SortMergeJoin is not enabled") None - } */ - - case aggregate: BaseAggregateExec - if (aggregate.isInstanceOf[HashAggregateExec] || - aggregate.isInstanceOf[ObjectHashAggregateExec]) && - CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) => - val groupingExpressions = aggregate.groupingExpressions - val aggregateExpressions = aggregate.aggregateExpressions - val aggregateAttributes = aggregate.aggregateAttributes - val resultExpressions = aggregate.resultExpressions - val child = aggregate.child - - if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) { - withInfo(op, "No group by or aggregation") - return None } - // Aggregate expressions with filter are not supported yet. - if (aggregateExpressions.exists(_.filter.isDefined)) { - withInfo(op, "Aggregate expression with filter is not supported") - return None - } - - if (groupingExpressions.exists(expr => - expr.dataType match { - case _: MapType => true - case _ => false - })) { - withInfo(op, "Grouping on map types is not supported") - return None - } - - val groupingExprsWithInput = - groupingExpressions.map(expr => expr.name -> exprToProto(expr, child.output)) - - val emptyExprs = groupingExprsWithInput.collect { - case (expr, proto) if proto.isEmpty => expr - } - - if (emptyExprs.nonEmpty) { - withInfo(op, s"Unsupported group expressions: ${emptyExprs.mkString(", ")}") - return None - } - - val groupingExprs = groupingExprsWithInput.map(_._2) - - // In some of the cases, the aggregateExpressions could be empty. - // For example, if the aggregate functions only have group by or if the aggregate - // functions only have distinct aggregate functions: - // - // SELECT COUNT(distinct col2), col1 FROM test group by col1 - // +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] ) - // +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36] - // +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] ) - // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) - // +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ... - // +- HashAggregate (keys =[col1#6, col2#7], functions =[] ) - // +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ...... - // If the aggregateExpressions is empty, we only want to build groupingExpressions, - // and skip processing of aggregateExpressions. - if (aggregateExpressions.isEmpty) { - val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() - hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withInfo( - op, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None - } - hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - Some(builder.setHashAgg(hashAggBuilder).build()) - } else { - val modes = aggregateExpressions.map(_.mode).distinct - - if (modes.size != 1) { - // This shouldn't happen as all aggregation expressions should share the same mode. - // Fallback to Spark nevertheless here. 
- withInfo(op, "All aggregate expressions do not have the same mode") - return None - } - - val mode = modes.head match { - case Partial => CometAggregateMode.Partial - case Final => CometAggregateMode.Final - case _ => - withInfo(op, s"Unsupported aggregation mode ${modes.head}") - return None - } - - // In final mode, the aggregate expressions are bound to the output of the - // child and partial aggregate expressions buffer attributes produced by partial - // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, - // we don't have to do this because we don't use the merging expression. - val binding = mode != CometAggregateMode.Final - // `output` is only used when `binding` is true (i.e., non-Final) - val output = child.output - - val aggExprs = - aggregateExpressions.map(aggExprToProto(_, output, binding, op.conf)) - if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && - aggExprs.forall(_.isDefined)) { - val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() - hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) - hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) - if (mode == CometAggregateMode.Final) { - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withInfo( - op, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None - } - hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - } - hashAggBuilder.setModeValue(mode.getNumber) - Some(builder.setHashAgg(hashAggBuilder).build()) - } else { - val allChildren: Seq[Expression] = - groupingExpressions ++ aggregateExpressions ++ aggregateAttributes - withInfo(op, allChildren: _*) - None - } - } - - case join: HashJoin => - // `HashJoin` has only two implementations in Spark, but we check the type of the join to - // make sure we are handling the correct join type. 
- if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) && - join.isInstanceOf[ShuffledHashJoinExec]) && - !(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) && - join.isInstanceOf[BroadcastHashJoinExec])) { - withInfo(join, s"Invalid hash join type ${join.nodeName}") - return None - } - - if (join.buildSide == BuildRight && join.joinType == LeftAnti) { - // https://github.com/apache/datafusion-comet/issues/457 - withInfo(join, "BuildRight with LeftAnti is not supported") - return None - } - - val condition = join.condition.map { cond => - val condProto = exprToProto(cond, join.left.output ++ join.right.output) - if (condProto.isEmpty) { - withInfo(join, cond) - return None - } - condProto.get - } - - val joinType = join.joinType match { - case Inner => JoinType.Inner - case LeftOuter => JoinType.LeftOuter - case RightOuter => JoinType.RightOuter - case FullOuter => JoinType.FullOuter - case LeftSemi => JoinType.LeftSemi - case LeftAnti => JoinType.LeftAnti - case _ => - // Spark doesn't support other join types - withInfo(join, s"Unsupported join type ${join.joinType}") - return None - } - - val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) - val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) - - if (leftKeys.forall(_.isDefined) && - rightKeys.forall(_.isDefined) && - childOp.nonEmpty) { - val joinBuilder = OperatorOuterClass.HashJoin - .newBuilder() - .setJoinType(joinType) - .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) - .addAllRightJoinKeys(rightKeys.map(_.get).asJava) - .setBuildSide( - if (join.buildSide == BuildLeft) BuildSide.BuildLeft else BuildSide.BuildRight) - condition.foreach(joinBuilder.setCondition) - Some(builder.setHashJoin(joinBuilder).build()) - } else { - val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys - withInfo(join, allExprs: _*) - None - } - - case join: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) => - // `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec. 
- def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { - keys.map(SortOrder(_, Ascending)) - } - - def getKeyOrdering( - keys: Seq[Expression], - childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = { - val requiredOrdering = requiredOrders(keys) - if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) { - keys.zip(childOutputOrdering).map { case (key, childOrder) => - val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key - SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq) - } - } else { - requiredOrdering - } - } - - if (join.condition.isDefined && - !CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED - .get(conf)) { - withInfo( - join, - s"${CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key} is not enabled", - join.condition.get) - return None - } - - val condition = join.condition.map { cond => - val condProto = exprToProto(cond, join.left.output ++ join.right.output) - if (condProto.isEmpty) { - withInfo(join, cond) - return None - } - condProto.get - } - - val joinType = join.joinType match { - case Inner => JoinType.Inner - case LeftOuter => JoinType.LeftOuter - case RightOuter => JoinType.RightOuter - case FullOuter => JoinType.FullOuter - case LeftSemi => JoinType.LeftSemi - case LeftAnti => JoinType.LeftAnti - case _ => - // Spark doesn't support other join types - withInfo(op, s"Unsupported join type ${join.joinType}") - return None - } - - // Checks if the join keys are supported by DataFusion SortMergeJoin. - val errorMsgs = join.leftKeys.flatMap { key => - if (!supportedSortMergeJoinEqualType(key.dataType)) { - Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}") - } else { - None - } - } - - if (errorMsgs.nonEmpty) { - withInfo(op, errorMsgs.flatten.mkString("\n")) - return None - } - - val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) - val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) - - val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering) - .map(exprToProto(_, join.left.output)) - - if (sortOptions.forall(_.isDefined) && - leftKeys.forall(_.isDefined) && - rightKeys.forall(_.isDefined) && - childOp.nonEmpty) { - val joinBuilder = OperatorOuterClass.SortMergeJoin - .newBuilder() - .setJoinType(joinType) - .addAllSortOptions(sortOptions.map(_.get).asJava) - .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) - .addAllRightJoinKeys(rightKeys.map(_.get).asJava) - condition.map(joinBuilder.setCondition) - Some(builder.setSortMergeJoin(joinBuilder).build()) - } else { - val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys - withInfo(join, allExprs: _*) - None - } - - case join: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) => - withInfo(join, "SortMergeJoin is not enabled") - None - case op if isCometSink(op) => val supportedTypes = op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) @@ -1581,7 +1095,6 @@ object QueryPlanSerde extends Logging with CometExprShim { } // scalastyle:off - /** * Align w/ Arrow's * [[https://github.com/apache/arrow-rs/blob/55.2.0/arrow-ord/src/rank.rs#L30-L40 can_rank]] and @@ -1589,7 +1102,7 @@ object QueryPlanSerde extends Logging with CometExprShim { * * TODO: Include SparkSQL's [[YearMonthIntervalType]] and [[DayTimeIntervalType]] */ - // scalastyle:off + // scalastyle:on def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = { def canRank(dt: DataType): Boolean = { dt match { @@ -1626,83 +1139,6 @@ object 
QueryPlanSerde extends Logging with CometExprShim { } } - private def validatePartitionAndSortSpecsForWindowFunc( - partitionSpec: Seq[Expression], - orderSpec: Seq[SortOrder], - op: SparkPlan): Boolean = { - if (partitionSpec.length != orderSpec.length) { - return false - } - - val partitionColumnNames = partitionSpec.collect { - case a: AttributeReference => a.name - case other => - withInfo(op, s"Unsupported partition expression: ${other.getClass.getSimpleName}") - return false - } - - val orderColumnNames = orderSpec.collect { case s: SortOrder => - s.child match { - case a: AttributeReference => a.name - case other => - withInfo(op, s"Unsupported sort expression: ${other.getClass.getSimpleName}") - return false - } - } - - if (partitionColumnNames.zip(orderColumnNames).exists { case (partCol, orderCol) => - partCol != orderCol - }) { - withInfo(op, "Partitioning and sorting specifications must be the same.") - return false - } - - true - } - - private def schema2Proto( - fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = { - val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder() - fields.map(field => { - fieldBuilder.setName(field.name) - fieldBuilder.setDataType(serializeDataType(field.dataType).get) - fieldBuilder.setNullable(field.nullable) - fieldBuilder.build() - }) - } - - private def partition2Proto( - partition: FilePartition, - nativeScanBuilder: OperatorOuterClass.NativeScan.Builder, - partitionSchema: StructType): Unit = { - val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder() - partition.files.foreach(file => { - // Process the partition values - val partitionValues = file.partitionValues - assert(partitionValues.numFields == partitionSchema.length) - val partitionVals = - partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value, i) => - val attr = partitionSchema(i) - val valueProto = exprToProto(Literal(value, attr.dataType), Seq.empty) - // In `CometScanRule`, we have already checked that all partition values are - // supported. So, we can safely use `get` here. - assert( - valueProto.isDefined, - s"Unsupported partition value: $value, type: ${attr.dataType}") - valueProto.get - } - - val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder() - partitionVals.foreach(fileBuilder.addPartitionValues) - fileBuilder - .setFilePath(file.filePath.toString) - .setStart(file.start) - .setLength(file.length) - .setFileSize(file.fileSize) - partitionBuilder.addPartitionedFile(fileBuilder.build()) - }) - nativeScanBuilder.addFilePartitions(partitionBuilder.build()) - } } sealed trait SupportLevel @@ -1726,131 +1162,3 @@ case class Incompatible(notes: Option[String] = None) extends SupportLevel /** Comet does not support this feature */ case class Unsupported(notes: Option[String] = None) extends SupportLevel - -/** - * Trait for providing serialization logic for operators. - */ -trait CometOperatorSerde[T <: SparkPlan] { - - /** - * Convert a Spark operator into a protocol buffer representation that can be passed into native - * code. - * - * @param op - * The Spark operator. - * @param builder - * The protobuf builder for the operator. - * @param childOp - * Child operators that have already been converted to Comet. - * @return - * Protocol buffer representation, or None if the operator could not be converted. In this - * case it is expected that the input operator will have been tagged with reasons why it could - * not be converted. 
- */ - def convert( - op: T, - builder: Operator.Builder, - childOp: Operator*): Option[OperatorOuterClass.Operator] - - /** - * Get the optional Comet configuration entry that is used to enable or disable native support - * for this operator. - */ - def enabledConfig: Option[ConfigEntry[Boolean]] -} - -/** - * Trait for providing serialization logic for expressions. - */ -trait CometExpressionSerde[T <: Expression] { - - /** - * Get a short name for the expression that can be used as part of a config key related to the - * expression, such as enabling or disabling that expression. - * - * @param expr - * The Spark expression. - * @return - * Short name for the expression, defaulting to the Spark class name - */ - def getExprConfigName(expr: T): String = expr.getClass.getSimpleName - - /** - * Determine the support level of the expression based on its attributes. - * - * @param expr - * The Spark expression. - * @return - * Support level (Compatible, Incompatible, or Unsupported). - */ - def getSupportLevel(expr: T): SupportLevel = Compatible(None) - - /** - * Convert a Spark expression into a protocol buffer representation that can be passed into - * native code. - * - * @param expr - * The Spark expression. - * @param inputs - * The input attributes. - * @param binding - * Whether the attributes are bound (this is only relevant in aggregate expressions). - * @return - * Protocol buffer representation, or None if the expression could not be converted. In this - * case it is expected that the input expression will have been tagged with reasons why it - * could not be converted. - */ - def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] -} - -/** - * Trait for providing serialization logic for aggregate expressions. - */ -trait CometAggregateExpressionSerde[T <: AggregateFunction] { - - /** - * Get a short name for the expression that can be used as part of a config key related to the - * expression, such as enabling or disabling that expression. - * - * @param expr - * The Spark expression. - * @return - * Short name for the expression, defaulting to the Spark class name - */ - def getExprConfigName(expr: T): String = expr.getClass.getSimpleName - - /** - * Convert a Spark expression into a protocol buffer representation that can be passed into - * native code. - * - * @param aggExpr - * The aggregate expression. - * @param expr - * The aggregate function. - * @param inputs - * The input attributes. - * @param binding - * Whether the attributes are bound (this is only relevant in aggregate expressions). - * @param conf - * SQLConf - * @return - * Protocol buffer representation, or None if the expression could not be converted. In this - * case it is expected that the input expression will have been tagged with reasons why it - * could not be converted. - */ - def convert( - aggExpr: AggregateExpression, - expr: T, - inputs: Seq[Attribute], - binding: Boolean, - conf: SQLConf): Option[ExprOuterClass.AggExpr] -} - -/** Serde for scalar function. */ -case class CometScalarFunction[T <: Expression](name: String) extends CometExpressionSerde[T] { - override def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { - val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProto(name, childExpr: _*) - optExprWithInfo(optExpr, expr, expr.children: _*) - } -}
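---

Editor's note (not part of the patch): for readers following the refactor, below is a minimal sketch of what one of the extracted converters could look like against the `CometOperatorSerde[T <: SparkPlan]` trait removed from this file above (assumed to remain available from its new location). It mirrors the old inline `LocalLimitExec` handling that this diff deletes; the object name `CometLocalLimitSketch` is illustrative only — the converter actually referenced by the new match arm is `CometLocalLimit`, whose real implementation may differ.

import org.apache.spark.sql.execution.LocalLimitExec

import org.apache.comet.{CometConf, ConfigEntry}
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.OperatorOuterClass.Operator

object CometLocalLimitSketch extends CometOperatorSerde[LocalLimitExec] {

  // Gate native support behind the same flag the QueryPlanSerde match arm checks.
  override def enabledConfig: Option[ConfigEntry[Boolean]] =
    Some(CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED)

  override def convert(
      op: LocalLimitExec,
      builder: Operator.Builder,
      childOp: Operator*): Option[OperatorOuterClass.Operator] = {
    if (childOp.nonEmpty) {
      // LocalLimit doesn't use offset, but it shares the Limit serde message,
      // so offset is set to zero (same behaviour as the inline code removed above).
      val limitBuilder = OperatorOuterClass.Limit
        .newBuilder()
        .setLimit(op.limit)
        .setOffset(0)
      Some(builder.setLimit(limitBuilder).build())
    } else {
      withInfo(op, "No child operator")
      None
    }
  }
}

With converters shaped like this, QueryPlanSerde reduces to the thin dispatch seen in the `+` lines of this diff, e.g. `case limit: LocalLimitExec if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) => CometLocalLimit.convert(limit, builder, childOp: _*)`, and unsupported cases fall back to Spark after tagging the plan via `withInfo`.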