Skip to content

Commit 2811793

Browse files
authored
refactor UnaryMinus serde (#2378)
1 parent 72eb0e9 commit 2811793

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
138138
classOf[Sqrt] -> CometScalarFunction("sqrt"),
139139
classOf[Subtract] -> CometSubtract,
140140
classOf[Tan] -> CometScalarFunction("tan"),
141-
// TODO UnaryMinus
141+
classOf[UnaryMinus] -> CometUnaryMinus,
142142
classOf[Unhex] -> CometUnhex)
143143

144144
private val mapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
@@ -942,22 +942,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
942942
None
943943
}
944944

945-
case UnaryMinus(child, failOnError) =>
946-
val childExpr = exprToProtoInternal(child, inputs, binding)
947-
if (childExpr.isDefined) {
948-
val builder = ExprOuterClass.UnaryMinus.newBuilder()
949-
builder.setChild(childExpr.get)
950-
builder.setFailOnError(failOnError)
951-
Some(
952-
ExprOuterClass.Expr
953-
.newBuilder()
954-
.setUnaryMinus(builder)
955-
.build())
956-
} else {
957-
withInfo(expr, child)
958-
None
959-
}
960-
961945
// With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
962946
// char types.
963947
// See https://github.com/apache/spark/pull/38151

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import scala.math.min
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EmptyRow, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Round, Subtract}
24+
import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EmptyRow, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Round, Subtract, UnaryMinus}
2525
import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType}
2626

2727
import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -370,3 +370,25 @@ object CometRound extends CometExpressionSerde[Round] {
370370

371371
}
372372
}
373+
object CometUnaryMinus extends CometExpressionSerde[UnaryMinus] {
374+
375+
override def convert(
376+
expr: UnaryMinus,
377+
inputs: Seq[Attribute],
378+
binding: Boolean): Option[ExprOuterClass.Expr] = {
379+
val childExpr = exprToProtoInternal(expr.child, inputs, binding)
380+
if (childExpr.isDefined) {
381+
val builder = ExprOuterClass.UnaryMinus.newBuilder()
382+
builder.setChild(childExpr.get)
383+
builder.setFailOnError(expr.failOnError)
384+
Some(
385+
ExprOuterClass.Expr
386+
.newBuilder()
387+
.setUnaryMinus(builder)
388+
.build())
389+
} else {
390+
withInfo(expr, expr.child)
391+
None
392+
}
393+
}
394+
}

0 commit comments

Comments
 (0)