|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +package org.apache.comet.serde |
| 21 | + |
| 22 | +import scala.jdk.CollectionConverters._ |
| 23 | + |
| 24 | +import org.apache.spark.sql.catalyst.expressions.Expression |
| 25 | +import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} |
| 26 | +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} |
| 27 | +import org.apache.spark.sql.types.MapType |
| 28 | + |
| 29 | +import org.apache.comet.{CometConf, ConfigEntry} |
| 30 | +import org.apache.comet.CometSparkSessionExtensions.withInfo |
| 31 | +import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} |
| 32 | +import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto} |
| 33 | + |
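| | +/** |
| | + * Shared logic for converting Spark hash-based aggregate operators into the Comet |
| | + * native `HashAggregate` protobuf operator. |
| | + */ |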
| 34 | +trait CometBaseAggregate { |
| 35 | + |
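| | + /** |
| | + * Builds the protobuf `HashAggregate` operator for the given Spark aggregate. Returns |
| | + * `None` (recording the fallback reason via `withInfo`) when the grouping, aggregate, |
| | + * or result expressions cannot be converted to their native counterparts. |
| | + */ |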
| 36 | + def doConvert( |
| 37 | + aggregate: BaseAggregateExec, |
| 38 | + builder: Operator.Builder, |
| 39 | + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { |
| 40 | + val groupingExpressions = aggregate.groupingExpressions |
| 41 | + val aggregateExpressions = aggregate.aggregateExpressions |
| 42 | + val aggregateAttributes = aggregate.aggregateAttributes |
| 43 | + val resultExpressions = aggregate.resultExpressions |
| 44 | + val child = aggregate.child |
| 45 | + |
| 46 | + if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) { |
| 47 | + withInfo(aggregate, "No group by or aggregation") |
| 48 | + return None |
| 49 | + } |
| 50 | + |
| 51 | + // Aggregate expressions with filter are not supported yet. |
| 52 | + if (aggregateExpressions.exists(_.filter.isDefined)) { |
| 53 | + withInfo(aggregate, "Aggregate expression with filter is not supported") |
| 54 | + return None |
| 55 | + } |
| 56 | + |
| 57 | + if (groupingExpressions.exists(expr => |
| 58 | + expr.dataType match { |
| 59 | + case _: MapType => true |
| 60 | + case _ => false |
| 61 | + })) { |
| 62 | + withInfo(aggregate, "Grouping on map types is not supported") |
| 63 | + return None |
| 64 | + } |
| 65 | + |
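| | + // Convert each grouping expression to its protobuf form, keeping the original name |
| | + // so that unsupported expressions can be reported in the fallback message. |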
| 66 | + val groupingExprsWithInput = |
| 67 | + groupingExpressions.map(expr => expr.name -> exprToProto(expr, child.output)) |
| 68 | + |
| 69 | + val emptyExprs = groupingExprsWithInput.collect { |
| 70 | + case (expr, proto) if proto.isEmpty => expr |
| 71 | + } |
| 72 | + |
| 73 | + if (emptyExprs.nonEmpty) { |
| 74 | + withInfo(aggregate, s"Unsupported group expressions: ${emptyExprs.mkString(", ")}") |
| 75 | + return None |
| 76 | + } |
| 77 | + |
| 78 | + val groupingExprs = groupingExprsWithInput.map(_._2) |
| 79 | + |
| 80 | + // In some cases, aggregateExpressions can be empty. For example, when the query |
| 81 | + // only groups by columns, or when it only uses distinct aggregate functions, as in |
| 82 | + // the plan for: |
| 83 | + // |
| 84 | + // SELECT COUNT(DISTINCT col2), col1 FROM test GROUP BY col1 |
| 85 | + // +- HashAggregate(keys=[col1#6], functions=[count(distinct col2#7)]) |
| 86 | + // +- Exchange hashpartitioning(col1#6, 10), ENSURE_REQUIREMENTS, [plan_id=36] |
| 87 | + // +- HashAggregate(keys=[col1#6], functions=[partial_count(distinct col2#7)]) |
| 88 | + // +- HashAggregate(keys=[col1#6, col2#7], functions=[]) |
| 89 | + // +- Exchange hashpartitioning(col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ... |
| 90 | + // +- HashAggregate(keys=[col1#6, col2#7], functions=[]) |
| 91 | + // +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ... |
| 92 | + // When aggregateExpressions is empty, we only build the grouping expressions and |
| 93 | + // skip processing of the aggregate expressions. |
| 94 | + if (aggregateExpressions.isEmpty) { |
| 95 | + val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() |
| 96 | + hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) |
| 97 | + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes |
| 98 | + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) |
| 99 | + if (resultExprs.exists(_.isEmpty)) { |
| 100 | + withInfo( |
| 101 | + aggregate, |
| 102 | + s"Unsupported result expressions found in: $resultExpressions", |
| 103 | + resultExpressions: _*) |
| 104 | + return None |
| 105 | + } |
| 106 | + hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) |
| 107 | + Some(builder.setHashAgg(hashAggBuilder).build()) |
| 108 | + } else { |
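| | + // Non-empty aggregate expressions: they must all share a single supported mode |
| | + // (Partial or Final) for the native conversion to proceed. |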
| 109 | + val modes = aggregateExpressions.map(_.mode).distinct |
| 110 | + |
| 111 | + if (modes.size != 1) { |
| 112 | + // This shouldn't happen, as all aggregate expressions should share the same mode. |
| 113 | + // Fall back to Spark nevertheless. |
| 114 | + withInfo(aggregate, "Aggregate expressions do not all share the same mode") |
| 115 | + return None |
| 116 | + } |
| 117 | + |
| 118 | + val mode = modes.head match { |
| 119 | + case Partial => CometAggregateMode.Partial |
| 120 | + case Final => CometAggregateMode.Final |
| 121 | + case _ => |
| 122 | + withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}") |
| 123 | + return None |
| 124 | + } |
| 125 | + |
| 126 | + // In final mode, Spark binds the aggregate expressions to the child output plus |
| 127 | + // the aggregation buffer attributes produced by the partial aggregation. This is |
| 128 | + // done internally in Spark's `HashAggregateExec`. In Comet we don't have to do |
| 129 | + // this because we don't use the merging expression. |
| 130 | + val binding = mode != CometAggregateMode.Final |
| 131 | + // `output` is only used when `binding` is true (i.e., non-Final) |
| 132 | + val output = child.output |
| 133 | + |
| 134 | + val aggExprs = |
| 135 | + aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf)) |
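| | + // Emit the native operator only when the child was converted and every grouping |
| | + // and aggregate expression mapped to a protobuf expression. |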
| 136 | + if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && |
| 137 | + aggExprs.forall(_.isDefined)) { |
| 138 | + val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() |
| 139 | + hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) |
| 140 | + hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) |
| 141 | + if (mode == CometAggregateMode.Final) { |
| 142 | + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes |
| 143 | + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) |
| 144 | + if (resultExprs.exists(_.isEmpty)) { |
| 145 | + withInfo( |
| 146 | + aggregate, |
| 147 | + s"Unsupported result expressions found in: $resultExpressions", |
| 148 | + resultExpressions: _*) |
| 149 | + return None |
| 150 | + } |
| 151 | + hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) |
| 152 | + } |
| 153 | + hashAggBuilder.setModeValue(mode.getNumber) |
| 154 | + Some(builder.setHashAgg(hashAggBuilder).build()) |
| 155 | + } else { |
| 156 | + val allChildren: Seq[Expression] = |
| 157 | + groupingExpressions ++ aggregateExpressions ++ aggregateAttributes |
| 158 | + withInfo(aggregate, allChildren: _*) |
| 159 | + None |
| 160 | + } |
| 161 | + } |
| 162 | + |
| 163 | + } |
| 164 | + |
| 165 | +} |
| 166 | + |
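| | +/** Serde for Spark's `HashAggregateExec`, gated by `CometConf.COMET_EXEC_AGGREGATE_ENABLED`. */ |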
| 167 | +object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with CometBaseAggregate { |
| 168 | + |
| 169 | + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( |
| 170 | + CometConf.COMET_EXEC_AGGREGATE_ENABLED) |
| 171 | + |
| 172 | + override def convert( |
| 173 | + aggregate: HashAggregateExec, |
| 174 | + builder: Operator.Builder, |
| 175 | + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { |
| 176 | + doConvert(aggregate, builder, childOp: _*) |
| 177 | + } |
| 178 | +} |
| 179 | + |
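| | +/** |
| | + * Serde for `ObjectHashAggregateExec`, which Spark uses when an aggregate involves |
| | + * object-based (`TypedImperativeAggregate`) functions. It shares the conversion path |
| | + * in `CometBaseAggregate`. |
| | + */ |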
| 180 | +object CometObjectHashAggregate |
| 181 | + extends CometOperatorSerde[ObjectHashAggregateExec] |
| 182 | + with CometBaseAggregate { |
| 183 | + |
| 184 | + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( |
| 185 | + CometConf.COMET_EXEC_AGGREGATE_ENABLED) |
| 186 | + |
| 187 | + override def convert( |
| 188 | + aggregate: ObjectHashAggregateExec, |
| 189 | + builder: Operator.Builder, |
| 190 | + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { |
| 191 | + doConvert(aggregate, builder, childOp: _*) |
| 192 | + } |
| 193 | +} |