Skip to content

Commit 7e0ff1a

Browse files
kazantsev-maksim and Kazantsev Maksim authored
Chore: Refactor serde for math expressions (#2259)
* Maths expr refactor * Fix * Format --------- Co-authored-by: Kazantsev Maksim <[email protected]>
1 parent c9ab291 commit 7e0ff1a

File tree

2 files changed

+128
-71
lines changed

2 files changed

+128
-71
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 8 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,14 @@ object QueryPlanSerde extends Logging with CometExprShim {
170170
classOf[DateSub] -> CometDateSub,
171171
classOf[TruncDate] -> CometTruncDate,
172172
classOf[TruncTimestamp] -> CometTruncTimestamp,
173-
classOf[Flatten] -> CometFlatten)
173+
classOf[Flatten] -> CometFlatten,
174+
classOf[Atan2] -> CometAtan2,
175+
classOf[Ceil] -> CometCeil,
176+
classOf[Floor] -> CometFloor,
177+
classOf[Log] -> CometLog,
178+
classOf[Log10] -> CometLog10,
179+
classOf[Log2] -> CometLog2,
180+
classOf[Pow] -> CometScalarFunction[Pow]("pow"))
174181

175182
/**
176183
* Mapping of Spark aggregate expression class to Comet expression handler.
@@ -1108,12 +1115,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
11081115
// None
11091116
// }
11101117

1111-
case Atan2(left, right) =>
1112-
val leftExpr = exprToProtoInternal(left, inputs, binding)
1113-
val rightExpr = exprToProtoInternal(right, inputs, binding)
1114-
val optExpr = scalarFunctionExprToProto("atan2", leftExpr, rightExpr)
1115-
optExprWithInfo(optExpr, expr, left, right)
1116-
11171118
case Hex(child) =>
11181119
val childExpr = exprToProtoInternal(child, inputs, binding)
11191120
val optExpr =
@@ -1131,56 +1132,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
11311132
scalarFunctionExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr)
11321133
optExprWithInfo(optExpr, expr, unHex._1)
11331134

1134-
case e @ Ceil(child) =>
1135-
val childExpr = exprToProtoInternal(child, inputs, binding)
1136-
child.dataType match {
1137-
case t: DecimalType if t.scale == 0 => // zero scale is no-op
1138-
childExpr
1139-
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
1140-
withInfo(e, s"Decimal type $t has negative scale")
1141-
None
1142-
case _ =>
1143-
val optExpr = scalarFunctionExprToProtoWithReturnType("ceil", e.dataType, childExpr)
1144-
optExprWithInfo(optExpr, expr, child)
1145-
}
1146-
1147-
case e @ Floor(child) =>
1148-
val childExpr = exprToProtoInternal(child, inputs, binding)
1149-
child.dataType match {
1150-
case t: DecimalType if t.scale == 0 => // zero scale is no-op
1151-
childExpr
1152-
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
1153-
withInfo(e, s"Decimal type $t has negative scale")
1154-
None
1155-
case _ =>
1156-
val optExpr = scalarFunctionExprToProtoWithReturnType("floor", e.dataType, childExpr)
1157-
optExprWithInfo(optExpr, expr, child)
1158-
}
1159-
1160-
// The expression for `log` functions is defined as null on numbers less than or equal
1161-
// to 0. This matches Spark and Hive behavior, where non positive values eval to null
1162-
// instead of NaN or -Infinity.
1163-
case Log(child) =>
1164-
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding)
1165-
val optExpr = scalarFunctionExprToProto("ln", childExpr)
1166-
optExprWithInfo(optExpr, expr, child)
1167-
1168-
case Log10(child) =>
1169-
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding)
1170-
val optExpr = scalarFunctionExprToProto("log10", childExpr)
1171-
optExprWithInfo(optExpr, expr, child)
1172-
1173-
case Log2(child) =>
1174-
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding)
1175-
val optExpr = scalarFunctionExprToProto("log2", childExpr)
1176-
optExprWithInfo(optExpr, expr, child)
1177-
1178-
case Pow(left, right) =>
1179-
val leftExpr = exprToProtoInternal(left, inputs, binding)
1180-
val rightExpr = exprToProtoInternal(right, inputs, binding)
1181-
val optExpr = scalarFunctionExprToProto("pow", leftExpr, rightExpr)
1182-
optExprWithInfo(optExpr, expr, left, right)
1183-
11841135
case RegExpReplace(subject, pattern, replacement, startPosition) =>
11851136
if (!RegExp.isSupportedPattern(pattern.toString) &&
11861137
!CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
@@ -1265,15 +1216,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
12651216
None
12661217
}
12671218

1268-
case BitwiseAnd(left, right) =>
1269-
createBinaryExpr(
1270-
expr,
1271-
left,
1272-
right,
1273-
inputs,
1274-
binding,
1275-
(builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
1276-
12771219
case n @ Not(In(_, _)) =>
12781220
CometNotIn.convert(n, inputs, binding)
12791221

@@ -1611,11 +1553,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
16111553
Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
16121554
}
16131555

1614-
private def nullIfNegative(expression: Expression): Expression = {
1615-
val zero = Literal.default(expression.dataType)
1616-
If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression)
1617-
}
1618-
16191556
/**
16201557
* Returns true if given datatype is supported as a key in DataFusion sort merge join.
16211558
*/
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.serde
21+
22+
import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, Expression, Floor, If, LessThanOrEqual, Literal, Log, Log10, Log2}
23+
import org.apache.spark.sql.types.DecimalType
24+
25+
import org.apache.comet.CometSparkSessionExtensions.withInfo
26+
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
27+
28+
/**
 * Serde handler for Spark's `Atan2` expression.
 *
 * Both operands are converted to protobuf and then wrapped in a scalar function call named
 * "atan2".
 */
object CometAtan2 extends CometExpressionSerde[Atan2] {

  /**
   * @param expr
   *   the `Atan2` expression to serialize
   * @param inputs
   *   attributes forwarded unchanged to `exprToProtoInternal` for each operand
   * @param binding
   *   flag forwarded unchanged to `exprToProtoInternal` for each operand
   * @return
   *   the result of `optExprWithInfo` applied to the "atan2" scalar function proto
   */
  override def convert(
      expr: Atan2,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val yOperand = expr.left
    val xOperand = expr.right
    // Convert both operands up front, left first, exactly as the original code did.
    val yProto = exprToProtoInternal(yOperand, inputs, binding)
    val xProto = exprToProtoInternal(xOperand, inputs, binding)
    optExprWithInfo(scalarFunctionExprToProto("atan2", yProto, xProto), expr, yOperand, xOperand)
  }
}
39+
40+
/**
 * Serde handler for Spark's `Ceil` expression.
 *
 * Decimal inputs are special-cased on their scale:
 *   - scale 0: the converted child is returned as-is, since ceil is a no-op
 *   - negative scale: conversion is refused and a reason is recorded on the expression
 *
 * Every other input is serialized as a "ceil" scalar function whose return type is the Spark
 * expression's own data type.
 */
object CometCeil extends CometExpressionSerde[Ceil] {

  /**
   * @param expr
   *   the `Ceil` expression to serialize
   * @param inputs
   *   attributes forwarded unchanged to `exprToProtoInternal`
   * @param binding
   *   flag forwarded unchanged to `exprToProtoInternal`
   * @return
   *   the converted child for zero-scale decimals, `None` for negative-scale decimals,
   *   otherwise the result of `optExprWithInfo` on the "ceil" scalar function proto
   */
  override def convert(
      expr: Ceil,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val operand = expr.child
    // The child is converted before inspecting its type, matching the original ordering.
    val operandProto = exprToProtoInternal(operand, inputs, binding)
    operand.dataType match {
      case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
        withInfo(expr, s"Decimal type $t has negative scale")
        None
      case t: DecimalType if t.scale == 0 => // zero scale is no-op
        operandProto
      case _ =>
        val ceilProto =
          scalarFunctionExprToProtoWithReturnType("ceil", expr.dataType, operandProto)
        optExprWithInfo(ceilProto, expr, operand)
    }
  }
}
58+
59+
/**
 * Serde handler for Spark's `Floor` expression.
 *
 * Decimal inputs are special-cased on their scale:
 *   - scale 0: the converted child is returned as-is, since floor is a no-op
 *   - negative scale: conversion is refused and a reason is recorded on the expression
 *
 * Every other input is serialized as a "floor" scalar function whose return type is the Spark
 * expression's own data type.
 */
object CometFloor extends CometExpressionSerde[Floor] {

  /**
   * @param expr
   *   the `Floor` expression to serialize
   * @param inputs
   *   attributes forwarded unchanged to `exprToProtoInternal`
   * @param binding
   *   flag forwarded unchanged to `exprToProtoInternal`
   * @return
   *   the converted child for zero-scale decimals, `None` for negative-scale decimals,
   *   otherwise the result of `optExprWithInfo` on the "floor" scalar function proto
   */
  override def convert(
      expr: Floor,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val operand = expr.child
    // The child is converted before inspecting its type, matching the original ordering.
    val operandProto = exprToProtoInternal(operand, inputs, binding)
    operand.dataType match {
      case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
        withInfo(expr, s"Decimal type $t has negative scale")
        None
      case t: DecimalType if t.scale == 0 => // zero scale is no-op
        operandProto
      case _ =>
        val floorProto =
          scalarFunctionExprToProtoWithReturnType("floor", expr.dataType, operandProto)
        optExprWithInfo(floorProto, expr, operand)
    }
  }
}
77+
78+
// The `log` family of expressions (Log, Log10, Log2) is defined as null for inputs less than
// or equal to 0. This matches Spark and Hive behavior, where non-positive values evaluate to
// null instead of NaN or -Infinity.

/**
 * Serde handler for Spark's `Log` expression.
 *
 * The child is first wrapped by `nullIfNegative`, so non-positive inputs become null, and the
 * wrapped child is then serialized as the argument of a scalar function named "ln".
 */
object CometLog extends CometExpressionSerde[Log] with MathExprBase {

  /**
   * @param expr
   *   the `Log` expression to serialize
   * @param inputs
   *   attributes forwarded unchanged to `exprToProtoInternal`
   * @param binding
   *   flag forwarded unchanged to `exprToProtoInternal`
   * @return
   *   the result of `optExprWithInfo` applied to the "ln" scalar function proto
   */
  override def convert(
      expr: Log,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val operand = expr.child
    val guardedProto = exprToProtoInternal(nullIfNegative(operand), inputs, binding)
    // Info is reported against the original child, not the null-guarded wrapper.
    optExprWithInfo(scalarFunctionExprToProto("ln", guardedProto), expr, operand)
  }
}
91+
92+
/**
 * Serde handler for Spark's `Log10` expression.
 *
 * The child is first wrapped by `nullIfNegative`, so non-positive inputs become null, and the
 * wrapped child is then serialized as the argument of a scalar function named "log10".
 */
object CometLog10 extends CometExpressionSerde[Log10] with MathExprBase {

  /**
   * @param expr
   *   the `Log10` expression to serialize
   * @param inputs
   *   attributes forwarded unchanged to `exprToProtoInternal`
   * @param binding
   *   flag forwarded unchanged to `exprToProtoInternal`
   * @return
   *   the result of `optExprWithInfo` applied to the "log10" scalar function proto
   */
  override def convert(
      expr: Log10,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val operand = expr.child
    val guardedProto = exprToProtoInternal(nullIfNegative(operand), inputs, binding)
    // Info is reported against the original child, not the null-guarded wrapper.
    optExprWithInfo(scalarFunctionExprToProto("log10", guardedProto), expr, operand)
  }
}
102+
103+
/**
 * Serde handler for Spark's `Log2` expression.
 *
 * The child is first wrapped by `nullIfNegative`, so non-positive inputs become null, and the
 * wrapped child is then serialized as the argument of a scalar function named "log2".
 */
object CometLog2 extends CometExpressionSerde[Log2] with MathExprBase {

  /**
   * @param expr
   *   the `Log2` expression to serialize
   * @param inputs
   *   attributes forwarded unchanged to `exprToProtoInternal`
   * @param binding
   *   flag forwarded unchanged to `exprToProtoInternal`
   * @return
   *   the result of `optExprWithInfo` applied to the "log2" scalar function proto
   */
  override def convert(
      expr: Log2,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val operand = expr.child
    val guardedProto = exprToProtoInternal(nullIfNegative(operand), inputs, binding)
    // Info is reported against the original child, not the null-guarded wrapper.
    optExprWithInfo(scalarFunctionExprToProto("log2", guardedProto), expr, operand)
  }
}
114+
115+
/** Shared helpers for the math expression serde objects defined in this file. */
sealed trait MathExprBase {

  /**
   * Wraps `expression` so that it evaluates to null whenever its value is less than or equal to
   * the default literal of its data type (presumably zero for numeric types).
   *
   * Despite the name, the guard is `<=`, so the default value itself is also mapped to null.
   *
   * @param expression
   *   the expression to guard
   * @return
   *   an `If` expression yielding a typed null literal when the guard holds, otherwise
   *   `expression`
   */
  protected def nullIfNegative(expression: Expression): Expression = {
    val defaultValue = Literal.default(expression.dataType)
    val isNonPositive = LessThanOrEqual(expression, defaultValue)
    val typedNull = Literal.create(null, expression.dataType)
    If(isNonPositive, typedNull, expression)
  }
}

0 commit comments

Comments
 (0)