Skip to content

Commit 46057a7

Browse files
authored
chore: Refactor serde for CheckOverflow (#2537)
1 parent 8366e1e commit 46057a7

File tree

5 files changed

+102
-25
lines changed

5 files changed

+102
-25
lines changed

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ jobs:
103103
value: |
104104
org.apache.comet.CometFuzzTestSuite
105105
org.apache.comet.CometFuzzAggregateSuite
106+
org.apache.comet.CometFuzzMathSuite
106107
org.apache.comet.DataGeneratorSuite
107108
- name: "shuffle"
108109
value: |

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ jobs:
6868
value: |
6969
org.apache.comet.CometFuzzTestSuite
7070
org.apache.comet.CometFuzzAggregateSuite
71+
org.apache.comet.CometFuzzMathSuite
7172
org.apache.comet.DataGeneratorSuite
7273
- name: "shuffle"
7374
value: |

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
221221
private val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
222222
// TODO SortOrder (?)
223223
// TODO PromotePrecision
224-
// TODO CheckOverflow
225224
// TODO KnownFloatingPointNormalized
226225
// TODO ScalarSubquery
227226
// TODO UnscaledValue
@@ -230,6 +229,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
230229
// TODO RegExpReplace
231230
classOf[Alias] -> CometAlias,
232231
classOf[AttributeReference] -> CometAttributeReference,
232+
classOf[CheckOverflow] -> CometCheckOverflow,
233233
classOf[Coalesce] -> CometCoalesce,
234234
classOf[Literal] -> CometLiteral,
235235
classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId,
@@ -772,28 +772,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
772772
// `PromotePrecision` is just a wrapper, don't need to serialize it.
773773
exprToProtoInternal(child, inputs, binding)
774774

775-
case CheckOverflow(child, dt, nullOnOverflow) =>
776-
val childExpr = exprToProtoInternal(child, inputs, binding)
777-
778-
if (childExpr.isDefined) {
779-
val builder = ExprOuterClass.CheckOverflow.newBuilder()
780-
builder.setChild(childExpr.get)
781-
builder.setFailOnError(!nullOnOverflow)
782-
783-
// `dataType` must be decimal type
784-
val dataType = serializeDataType(dt)
785-
builder.setDatatype(dataType.get)
786-
787-
Some(
788-
ExprOuterClass.Expr
789-
.newBuilder()
790-
.setCheckOverflow(builder)
791-
.build())
792-
} else {
793-
withInfo(expr, child)
794-
None
795-
}
796-
797775
case RegExpReplace(subject, pattern, replacement, startPosition) =>
798776
if (!RegExp.isSupportedPattern(pattern.toString) &&
799777
!CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {

spark/src/main/scala/org/apache/comet/serde/math.scala

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919

2020
package org.apache.comet.serde
2121

22-
import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex}
22+
import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex}
2323
import org.apache.spark.sql.types.DecimalType
2424

2525
import org.apache.comet.CometSparkSessionExtensions.withInfo
26-
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
26+
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
2727

2828
object CometAtan2 extends CometExpressionSerde[Atan2] {
2929
override def convert(
@@ -143,3 +143,41 @@ sealed trait MathExprBase {
143143
If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression)
144144
}
145145
}
146+
147+
object CometCheckOverflow extends CometExpressionSerde[CheckOverflow] {

  /**
   * `CheckOverflow` is only translatable when its result type is a decimal type; any other
   * result type is reported as unsupported.
   */
  override def getSupportLevel(expr: CheckOverflow): SupportLevel =
    expr.dataType match {
      case _: DecimalType => Compatible()
      case _ => Unsupported(Some("dataType must be DecimalType"))
    }

  /**
   * Serializes a Spark `CheckOverflow` expression into the Comet protobuf representation.
   * Returns `None` (after tagging the plan node via `withInfo`) when the child expression
   * cannot itself be serialized.
   */
  override def convert(
      expr: CheckOverflow,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    exprToProtoInternal(expr.child, inputs, binding) match {
      case Some(serializedChild) =>
        // `dataType` must be decimal type
        assert(expr.dataType.isInstanceOf[DecimalType])
        val overflowBuilder = ExprOuterClass.CheckOverflow
          .newBuilder()
          .setChild(serializedChild)
          // Spark's `nullOnOverflow` is the inverse of the native side's fail-on-error flag.
          .setFailOnError(!expr.nullOnOverflow)
          .setDatatype(serializeDataType(expr.dataType).get)
        Some(
          ExprOuterClass.Expr
            .newBuilder()
            .setCheckOverflow(overflowBuilder)
            .build())
      case None =>
        withInfo(expr, expr.child)
        None
    }
  }
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet
21+
22+
import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType}
23+
24+
class CometFuzzMathSuite extends CometFuzzTestBase {

  // Fuzz binary arithmetic over integer/long columns for each supported operator.
  Seq("+", "-", "*", "/", "div").foreach { op =>
    test(s"integer math: $op") {
      val df = spark.read.parquet(filename)
      // Pick the names of all integral (int/long) columns from the generated schema.
      val cols = df.schema.fields.collect {
        case field if field.dataType.isInstanceOf[IntegerType] ||
            field.dataType.isInstanceOf[LongType] =>
          field.name
      }
      df.createOrReplaceTempView("t1")
      val sql =
        s"SELECT ${cols(0)} $op ${cols(1)} FROM t1 ORDER BY ${cols(0)}, ${cols(1)} LIMIT 500"
      if (op == "div") {
        // cast(cast(c3#1975 as bigint) as decimal(19,0)) is not fully compatible with Spark (No overflow check)
        checkSparkAnswer(sql)
      } else {
        checkSparkAnswerAndOperator(sql)
      }
    }
  }

  // Fuzz binary arithmetic over decimal columns for each supported operator.
  Seq("+", "-", "*", "/", "div").foreach { op =>
    test(s"decimal math: $op") {
      val df = spark.read.parquet(filename)
      // Pick the names of all decimal columns from the generated schema.
      val cols = df.schema.fields.collect {
        case field if field.dataType.isInstanceOf[DecimalType] => field.name
      }
      df.createOrReplaceTempView("t1")
      val sql =
        s"SELECT ${cols(0)} $op ${cols(1)} FROM t1 ORDER BY ${cols(0)}, ${cols(1)} LIMIT 500"
      checkSparkAnswerAndOperator(sql)
    }
  }

}

0 commit comments

Comments
 (0)