diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 75d097d5f7..b2d4935d9b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -138,7 +138,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Sqrt] -> CometScalarFunction("sqrt"), classOf[Subtract] -> CometSubtract, classOf[Tan] -> CometScalarFunction("tan"), - // TODO UnaryMinus + classOf[UnaryMinus] -> CometUnaryMinus, classOf[Unhex] -> CometUnhex) private val mapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( @@ -942,22 +942,6 @@ object QueryPlanSerde extends Logging with CometExprShim { None } - case UnaryMinus(child, failOnError) => - val childExpr = exprToProtoInternal(child, inputs, binding) - if (childExpr.isDefined) { - val builder = ExprOuterClass.UnaryMinus.newBuilder() - builder.setChild(childExpr.get) - builder.setFailOnError(failOnError) - Some( - ExprOuterClass.Expr - .newBuilder() - .setUnaryMinus(builder) - .build()) - } else { - withInfo(expr, child) - None - } - // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for // char types. // See https://github.com/apache/spark/pull/38151 diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala index 52a9370386..0f1eeb758a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.math.min -import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EmptyRow, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Round, Subtract} +import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EmptyRow, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Round, Subtract, UnaryMinus} import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType} import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -370,3 +370,25 @@ object CometRound extends CometExpressionSerde[Round] { } } +object CometUnaryMinus extends CometExpressionSerde[UnaryMinus] { + + override def convert( + expr: UnaryMinus, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + if (childExpr.isDefined) { + val builder = ExprOuterClass.UnaryMinus.newBuilder() + builder.setChild(childExpr.get) + builder.setFailOnError(expr.failOnError) + Some( + ExprOuterClass.Expr + .newBuilder() + .setUnaryMinus(builder) + .build()) + } else { + withInfo(expr, expr.child) + None + } + } +}