Skip to content

Commit 9bfb5c4

Browse files
authored
chore: Refactor operator serde - part 1 (#2738)
1 parent 6b340d3 commit 9bfb5c4

14 files changed

+1261
-721
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.serde
21+
22+
import scala.jdk.CollectionConverters._
23+
24+
import org.apache.spark.sql.catalyst.expressions.Expression
25+
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
26+
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
27+
import org.apache.spark.sql.types.MapType
28+
29+
import org.apache.comet.{CometConf, ConfigEntry}
30+
import org.apache.comet.CometSparkSessionExtensions.withInfo
31+
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
32+
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto}
33+
34+
trait CometBaseAggregate {
35+
36+
def doConvert(
37+
aggregate: BaseAggregateExec,
38+
builder: Operator.Builder,
39+
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
40+
val groupingExpressions = aggregate.groupingExpressions
41+
val aggregateExpressions = aggregate.aggregateExpressions
42+
val aggregateAttributes = aggregate.aggregateAttributes
43+
val resultExpressions = aggregate.resultExpressions
44+
val child = aggregate.child
45+
46+
if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
47+
withInfo(aggregate, "No group by or aggregation")
48+
return None
49+
}
50+
51+
// Aggregate expressions with filter are not supported yet.
52+
if (aggregateExpressions.exists(_.filter.isDefined)) {
53+
withInfo(aggregate, "Aggregate expression with filter is not supported")
54+
return None
55+
}
56+
57+
if (groupingExpressions.exists(expr =>
58+
expr.dataType match {
59+
case _: MapType => true
60+
case _ => false
61+
})) {
62+
withInfo(aggregate, "Grouping on map types is not supported")
63+
return None
64+
}
65+
66+
val groupingExprsWithInput =
67+
groupingExpressions.map(expr => expr.name -> exprToProto(expr, child.output))
68+
69+
val emptyExprs = groupingExprsWithInput.collect {
70+
case (expr, proto) if proto.isEmpty => expr
71+
}
72+
73+
if (emptyExprs.nonEmpty) {
74+
withInfo(aggregate, s"Unsupported group expressions: ${emptyExprs.mkString(", ")}")
75+
return None
76+
}
77+
78+
val groupingExprs = groupingExprsWithInput.map(_._2)
79+
80+
// In some of the cases, the aggregateExpressions could be empty.
81+
// For example, if the aggregate functions only have group by or if the aggregate
82+
// functions only have distinct aggregate functions:
83+
//
84+
// SELECT COUNT(distinct col2), col1 FROM test group by col1
85+
// +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] )
86+
// +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36]
87+
// +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] )
88+
// +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
89+
// +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ...
90+
// +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
91+
// +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ......
92+
// If the aggregateExpressions is empty, we only want to build groupingExpressions,
93+
// and skip processing of aggregateExpressions.
94+
if (aggregateExpressions.isEmpty) {
95+
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
96+
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
97+
val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
98+
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
99+
if (resultExprs.exists(_.isEmpty)) {
100+
withInfo(
101+
aggregate,
102+
s"Unsupported result expressions found in: $resultExpressions",
103+
resultExpressions: _*)
104+
return None
105+
}
106+
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
107+
Some(builder.setHashAgg(hashAggBuilder).build())
108+
} else {
109+
val modes = aggregateExpressions.map(_.mode).distinct
110+
111+
if (modes.size != 1) {
112+
// This shouldn't happen as all aggregation expressions should share the same mode.
113+
// Fallback to Spark nevertheless here.
114+
withInfo(aggregate, "All aggregate expressions do not have the same mode")
115+
return None
116+
}
117+
118+
val mode = modes.head match {
119+
case Partial => CometAggregateMode.Partial
120+
case Final => CometAggregateMode.Final
121+
case _ =>
122+
withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}")
123+
return None
124+
}
125+
126+
// In final mode, the aggregate expressions are bound to the output of the
127+
// child and partial aggregate expressions buffer attributes produced by partial
128+
// aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
129+
// we don't have to do this because we don't use the merging expression.
130+
val binding = mode != CometAggregateMode.Final
131+
// `output` is only used when `binding` is true (i.e., non-Final)
132+
val output = child.output
133+
134+
val aggExprs =
135+
aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf))
136+
if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
137+
aggExprs.forall(_.isDefined)) {
138+
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
139+
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
140+
hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
141+
if (mode == CometAggregateMode.Final) {
142+
val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
143+
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
144+
if (resultExprs.exists(_.isEmpty)) {
145+
withInfo(
146+
aggregate,
147+
s"Unsupported result expressions found in: $resultExpressions",
148+
resultExpressions: _*)
149+
return None
150+
}
151+
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
152+
}
153+
hashAggBuilder.setModeValue(mode.getNumber)
154+
Some(builder.setHashAgg(hashAggBuilder).build())
155+
} else {
156+
val allChildren: Seq[Expression] =
157+
groupingExpressions ++ aggregateExpressions ++ aggregateAttributes
158+
withInfo(aggregate, allChildren: _*)
159+
None
160+
}
161+
}
162+
163+
}
164+
165+
}
166+
167+
object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with CometBaseAggregate {
168+
169+
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
170+
CometConf.COMET_EXEC_AGGREGATE_ENABLED)
171+
172+
override def convert(
173+
aggregate: HashAggregateExec,
174+
builder: Operator.Builder,
175+
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
176+
doConvert(aggregate, builder, childOp: _*)
177+
}
178+
}
179+
180+
object CometObjectHashAggregate
181+
extends CometOperatorSerde[ObjectHashAggregateExec]
182+
with CometBaseAggregate {
183+
184+
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
185+
CometConf.COMET_EXEC_AGGREGATE_ENABLED)
186+
187+
override def convert(
188+
aggregate: ObjectHashAggregateExec,
189+
builder: Operator.Builder,
190+
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
191+
doConvert(aggregate, builder, childOp: _*)
192+
}
193+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.serde
21+
22+
import org.apache.spark.sql.catalyst.expressions.Attribute
23+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
24+
import org.apache.spark.sql.internal.SQLConf
25+
26+
/**
27+
* Trait for providing serialization logic for aggregate expressions.
28+
*/
29+
trait CometAggregateExpressionSerde[T <: AggregateFunction] {
30+
31+
/**
32+
* Get a short name for the expression that can be used as part of a config key related to the
33+
* expression, such as enabling or disabling that expression.
34+
*
35+
* @param expr
36+
* The Spark expression.
37+
* @return
38+
* Short name for the expression, defaulting to the Spark class name
39+
*/
40+
def getExprConfigName(expr: T): String = expr.getClass.getSimpleName
41+
42+
/**
43+
* Convert a Spark expression into a protocol buffer representation that can be passed into
44+
* native code.
45+
*
46+
* @param aggExpr
47+
* The aggregate expression.
48+
* @param expr
49+
* The aggregate function.
50+
* @param inputs
51+
* The input attributes.
52+
* @param binding
53+
* Whether the attributes are bound (this is only relevant in aggregate expressions).
54+
* @param conf
55+
* SQLConf
56+
* @return
57+
* Protocol buffer representation, or None if the expression could not be converted. In this
58+
* case it is expected that the input expression will have been tagged with reasons why it
59+
* could not be converted.
60+
*/
61+
def convert(
62+
aggExpr: AggregateExpression,
63+
expr: T,
64+
inputs: Seq[Attribute],
65+
binding: Boolean,
66+
conf: SQLConf): Option[ExprOuterClass.AggExpr]
67+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.serde
21+
22+
import scala.jdk.CollectionConverters._
23+
24+
import org.apache.spark.sql.catalyst.expressions.Expression
25+
import org.apache.spark.sql.execution.ExpandExec
26+
27+
import org.apache.comet.{CometConf, ConfigEntry}
28+
import org.apache.comet.CometSparkSessionExtensions.withInfo
29+
import org.apache.comet.serde.OperatorOuterClass.Operator
30+
import org.apache.comet.serde.QueryPlanSerde.exprToProto
31+
32+
object CometExpand extends CometOperatorSerde[ExpandExec] {
33+
34+
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
35+
CometConf.COMET_EXEC_EXPAND_ENABLED)
36+
37+
override def convert(
38+
op: ExpandExec,
39+
builder: Operator.Builder,
40+
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
41+
var allProjExprs: Seq[Expression] = Seq()
42+
val projExprs = op.projections.flatMap(_.map(e => {
43+
allProjExprs = allProjExprs :+ e
44+
exprToProto(e, op.child.output)
45+
}))
46+
47+
if (projExprs.forall(_.isDefined) && childOp.nonEmpty) {
48+
val expandBuilder = OperatorOuterClass.Expand
49+
.newBuilder()
50+
.addAllProjectList(projExprs.map(_.get).asJava)
51+
.setNumExprPerProject(op.projections.head.size)
52+
Some(builder.setExpand(expandBuilder).build())
53+
} else {
54+
withInfo(op, allProjExprs: _*)
55+
None
56+
}
57+
58+
}
59+
60+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.serde
21+
22+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
23+
24+
/**
25+
* Trait for providing serialization logic for expressions.
26+
*/
27+
trait CometExpressionSerde[T <: Expression] {
28+
29+
/**
30+
* Get a short name for the expression that can be used as part of a config key related to the
31+
* expression, such as enabling or disabling that expression.
32+
*
33+
* @param expr
34+
* The Spark expression.
35+
* @return
36+
* Short name for the expression, defaulting to the Spark class name
37+
*/
38+
def getExprConfigName(expr: T): String = expr.getClass.getSimpleName
39+
40+
/**
41+
* Determine the support level of the expression based on its attributes.
42+
*
43+
* @param expr
44+
* The Spark expression.
45+
* @return
46+
* Support level (Compatible, Incompatible, or Unsupported).
47+
*/
48+
def getSupportLevel(expr: T): SupportLevel = Compatible(None)
49+
50+
/**
51+
* Convert a Spark expression into a protocol buffer representation that can be passed into
52+
* native code.
53+
*
54+
* @param expr
55+
* The Spark expression.
56+
* @param inputs
57+
* The input attributes.
58+
* @param binding
59+
* Whether the attributes are bound (this is only relevant in aggregate expressions).
60+
* @return
61+
* Protocol buffer representation, or None if the expression could not be converted. In this
62+
* case it is expected that the input expression will have been tagged with reasons why it
63+
* could not be converted.
64+
*/
65+
def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr]
66+
}

0 commit comments

Comments
 (0)