Skip to content

Commit a583db3

Browse files
Chore: refactor Comparison out of QueryPlanSerde (#2028)
1 parent 3e41b2b commit a583db3

File tree

2 files changed

+219
-99
lines changed

2 files changed

+219
-99
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 11 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
127127
classOf[MapValues] -> CometMapValues,
128128
classOf[MapFromArrays] -> CometMapFromArrays,
129129
classOf[GetMapValue] -> CometMapExtract,
130+
classOf[GreaterThan] -> CometGreaterThan,
131+
classOf[GreaterThanOrEqual] -> CometGreaterThanOrEqual,
132+
classOf[LessThan] -> CometLessThan,
133+
classOf[LessThanOrEqual] -> CometLessThanOrEqual,
134+
classOf[IsNull] -> CometIsNull,
135+
classOf[IsNotNull] -> CometIsNotNull,
136+
classOf[IsNaN] -> CometIsNaN,
137+
classOf[In] -> CometIn,
138+
classOf[InSet] -> CometInSet,
130139
classOf[Rand] -> CometRand,
131140
classOf[Randn] -> CometRandn,
132141
classOf[SparkPartitionID] -> CometSparkPartitionId,
@@ -684,42 +693,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
684693
binding,
685694
(builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr))
686695

687-
case GreaterThan(left, right) =>
688-
createBinaryExpr(
689-
expr,
690-
left,
691-
right,
692-
inputs,
693-
binding,
694-
(builder, binaryExpr) => builder.setGt(binaryExpr))
695-
696-
case GreaterThanOrEqual(left, right) =>
697-
createBinaryExpr(
698-
expr,
699-
left,
700-
right,
701-
inputs,
702-
binding,
703-
(builder, binaryExpr) => builder.setGtEq(binaryExpr))
704-
705-
case LessThan(left, right) =>
706-
createBinaryExpr(
707-
expr,
708-
left,
709-
right,
710-
inputs,
711-
binding,
712-
(builder, binaryExpr) => builder.setLt(binaryExpr))
713-
714-
case LessThanOrEqual(left, right) =>
715-
createBinaryExpr(
716-
expr,
717-
left,
718-
right,
719-
inputs,
720-
binding,
721-
(builder, binaryExpr) => builder.setLtEq(binaryExpr))
722-
723696
case Literal(value, dataType)
724697
if supportedDataType(dataType, allowComplex = value == null) =>
725698
val exprBuilder = ExprOuterClass.Literal.newBuilder()
@@ -1066,29 +1039,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
10661039
})
10671040
optExprWithInfo(optExpr, expr, child)
10681041

1069-
case IsNull(child) =>
1070-
createUnaryExpr(
1071-
expr,
1072-
child,
1073-
inputs,
1074-
binding,
1075-
(builder, unaryExpr) => builder.setIsNull(unaryExpr))
1076-
1077-
case IsNotNull(child) =>
1078-
createUnaryExpr(
1079-
expr,
1080-
child,
1081-
inputs,
1082-
binding,
1083-
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
1084-
1085-
case IsNaN(child) =>
1086-
val childExpr = exprToProtoInternal(child, inputs, binding)
1087-
val optExpr =
1088-
scalarFunctionExprToProtoWithReturnType("isnan", BooleanType, childExpr)
1089-
1090-
optExprWithInfo(optExpr, expr, child)
1091-
10921042
case SortOrder(child, direction, nullOrdering, _) =>
10931043
val childExpr = exprToProtoInternal(child, inputs, binding)
10941044

@@ -1458,20 +1408,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
14581408
binding,
14591409
(builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
14601410

1461-
case In(value, list) =>
1462-
in(expr, value, list, inputs, binding, negate = false)
1463-
1464-
case InSet(value, hset) =>
1465-
val valueDataType = value.dataType
1466-
val list = hset.map { setVal =>
1467-
Literal(setVal, valueDataType)
1468-
}.toSeq
1469-
// Change `InSet` to `In` expression
1470-
// We do Spark `InSet` optimization in native (DataFusion) side.
1471-
in(expr, value, list, inputs, binding, negate = false)
1472-
1473-
case Not(In(value, list)) =>
1474-
in(expr, value, list, inputs, binding, negate = true)
1411+
case Not(In(_, _)) =>
1412+
CometNotIn.convert(expr, inputs, binding)
14751413

14761414
case Not(child) =>
14771415
createUnaryExpr(
@@ -1815,32 +1753,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
18151753
}
18161754
}
18171755

1818-
def in(
1819-
expr: Expression,
1820-
value: Expression,
1821-
list: Seq[Expression],
1822-
inputs: Seq[Attribute],
1823-
binding: Boolean,
1824-
negate: Boolean): Option[Expr] = {
1825-
val valueExpr = exprToProtoInternal(value, inputs, binding)
1826-
val listExprs = list.map(exprToProtoInternal(_, inputs, binding))
1827-
if (valueExpr.isDefined && listExprs.forall(_.isDefined)) {
1828-
val builder = ExprOuterClass.In.newBuilder()
1829-
builder.setInValue(valueExpr.get)
1830-
builder.addAllLists(listExprs.map(_.get).asJava)
1831-
builder.setNegated(negate)
1832-
Some(
1833-
ExprOuterClass.Expr
1834-
.newBuilder()
1835-
.setIn(builder)
1836-
.build())
1837-
} else {
1838-
val allExprs = list ++ Seq(value)
1839-
withInfo(expr, allExprs: _*)
1840-
None
1841-
}
1842-
}
1843-
18441756
def scalarFunctionExprToProtoWithReturnType(
18451757
funcName: String,
18461758
returnType: DataType,
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.serde
21+
22+
import scala.collection.JavaConverters._
23+
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GreaterThan, GreaterThanOrEqual, In, InSet, IsNaN, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not}
25+
import org.apache.spark.sql.types.BooleanType
26+
27+
import org.apache.comet.CometSparkSessionExtensions.withInfo
28+
import org.apache.comet.serde.ExprOuterClass.Expr
29+
import org.apache.comet.serde.QueryPlanSerde._
30+
31+
object CometGreaterThan extends CometExpressionSerde {
32+
override def convert(
33+
expr: Expression,
34+
inputs: Seq[Attribute],
35+
binding: Boolean): Option[ExprOuterClass.Expr] = {
36+
val greaterThan = expr.asInstanceOf[GreaterThan]
37+
38+
createBinaryExpr(
39+
expr,
40+
greaterThan.left,
41+
greaterThan.right,
42+
inputs,
43+
binding,
44+
(builder, binaryExpr) => builder.setGt(binaryExpr))
45+
}
46+
}
47+
48+
object CometGreaterThanOrEqual extends CometExpressionSerde {
49+
override def convert(
50+
expr: Expression,
51+
inputs: Seq[Attribute],
52+
binding: Boolean): Option[ExprOuterClass.Expr] = {
53+
val greaterThanOrEqual = expr.asInstanceOf[GreaterThanOrEqual]
54+
55+
createBinaryExpr(
56+
expr,
57+
greaterThanOrEqual.left,
58+
greaterThanOrEqual.right,
59+
inputs,
60+
binding,
61+
(builder, binaryExpr) => builder.setGtEq(binaryExpr))
62+
}
63+
}
64+
65+
object CometLessThan extends CometExpressionSerde {
66+
override def convert(
67+
expr: Expression,
68+
inputs: Seq[Attribute],
69+
binding: Boolean): Option[ExprOuterClass.Expr] = {
70+
val lessThan = expr.asInstanceOf[LessThan]
71+
72+
createBinaryExpr(
73+
expr,
74+
lessThan.left,
75+
lessThan.right,
76+
inputs,
77+
binding,
78+
(builder, binaryExpr) => builder.setLt(binaryExpr))
79+
}
80+
}
81+
82+
object CometLessThanOrEqual extends CometExpressionSerde {
83+
override def convert(
84+
expr: Expression,
85+
inputs: Seq[Attribute],
86+
binding: Boolean): Option[ExprOuterClass.Expr] = {
87+
val lessThanOrEqual = expr.asInstanceOf[LessThanOrEqual]
88+
89+
createBinaryExpr(
90+
expr,
91+
lessThanOrEqual.left,
92+
lessThanOrEqual.right,
93+
inputs,
94+
binding,
95+
(builder, binaryExpr) => builder.setLtEq(binaryExpr))
96+
}
97+
}
98+
99+
object CometIsNull extends CometExpressionSerde {
100+
override def convert(
101+
expr: Expression,
102+
inputs: Seq[Attribute],
103+
binding: Boolean): Option[ExprOuterClass.Expr] = {
104+
val isNull = expr.asInstanceOf[IsNull]
105+
106+
createUnaryExpr(
107+
expr,
108+
isNull.child,
109+
inputs,
110+
binding,
111+
(builder, unaryExpr) => builder.setIsNull(unaryExpr))
112+
}
113+
}
114+
115+
object CometIsNotNull extends CometExpressionSerde {
116+
override def convert(
117+
expr: Expression,
118+
inputs: Seq[Attribute],
119+
binding: Boolean): Option[ExprOuterClass.Expr] = {
120+
val isNotNull = expr.asInstanceOf[IsNotNull]
121+
122+
createUnaryExpr(
123+
expr,
124+
isNotNull.child,
125+
inputs,
126+
binding,
127+
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
128+
}
129+
}
130+
131+
object CometIsNaN extends CometExpressionSerde {
132+
override def convert(
133+
expr: Expression,
134+
inputs: Seq[Attribute],
135+
binding: Boolean): Option[ExprOuterClass.Expr] = {
136+
val isNaN = expr.asInstanceOf[IsNaN]
137+
val childExpr = exprToProtoInternal(isNaN.child, inputs, binding)
138+
val optExpr = scalarFunctionExprToProtoWithReturnType("isnan", BooleanType, childExpr)
139+
140+
optExprWithInfo(optExpr, expr, isNaN.child)
141+
}
142+
}
143+
144+
object CometIn extends CometExpressionSerde {
145+
override def convert(
146+
expr: Expression,
147+
inputs: Seq[Attribute],
148+
binding: Boolean): Option[ExprOuterClass.Expr] = {
149+
val inExpr = expr.asInstanceOf[In]
150+
ComparisonUtils.in(expr, inExpr.value, inExpr.list, inputs, binding, negate = false)
151+
}
152+
}
153+
154+
object CometNotIn extends CometExpressionSerde {
155+
override def convert(
156+
expr: Expression,
157+
inputs: Seq[Attribute],
158+
binding: Boolean): Option[ExprOuterClass.Expr] = {
159+
val notExpr = expr.asInstanceOf[Not]
160+
val inExpr = notExpr.child.asInstanceOf[In]
161+
ComparisonUtils.in(expr, inExpr.value, inExpr.list, inputs, binding, negate = true)
162+
}
163+
}
164+
165+
object CometInSet extends CometExpressionSerde {
166+
override def convert(
167+
expr: Expression,
168+
inputs: Seq[Attribute],
169+
binding: Boolean): Option[ExprOuterClass.Expr] = {
170+
val inSetExpr = expr.asInstanceOf[InSet]
171+
val valueDataType = inSetExpr.child.dataType
172+
val list = inSetExpr.hset.map { setVal =>
173+
Literal(setVal, valueDataType)
174+
}.toSeq
175+
// Change `InSet` to `In` expression
176+
// We do Spark `InSet` optimization in native (DataFusion) side.
177+
ComparisonUtils.in(expr, inSetExpr.child, list, inputs, binding, negate = false)
178+
}
179+
}
180+
181+
object ComparisonUtils {
182+
183+
def in(
184+
expr: Expression,
185+
value: Expression,
186+
list: Seq[Expression],
187+
inputs: Seq[Attribute],
188+
binding: Boolean,
189+
negate: Boolean): Option[Expr] = {
190+
val valueExpr = exprToProtoInternal(value, inputs, binding)
191+
val listExprs = list.map(exprToProtoInternal(_, inputs, binding))
192+
if (valueExpr.isDefined && listExprs.forall(_.isDefined)) {
193+
val builder = ExprOuterClass.In.newBuilder()
194+
builder.setInValue(valueExpr.get)
195+
builder.addAllLists(listExprs.map(_.get).asJava)
196+
builder.setNegated(negate)
197+
Some(
198+
ExprOuterClass.Expr
199+
.newBuilder()
200+
.setIn(builder)
201+
.build())
202+
} else {
203+
val allExprs = list ++ Seq(value)
204+
withInfo(expr, allExprs: _*)
205+
None
206+
}
207+
}
208+
}

0 commit comments

Comments
 (0)