@@ -23,7 +23,6 @@ import java.util.Locale
2323
2424import scala .collection .JavaConverters ._
2525import scala .collection .mutable .ListBuffer
26- import scala .math .min
2726
2827import org .apache .spark .internal .Logging
2928import org .apache .spark .sql .catalyst .expressions ._
@@ -67,6 +66,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
6766 * Mapping of Spark expression class to Comet expression handler.
6867 */
6968 private val exprSerdeMap : Map [Class [_], CometExpressionSerde ] = Map (
69+ classOf [Add ] -> CometAdd ,
70+ classOf [Subtract ] -> CometSubtract ,
71+ classOf [Multiply ] -> CometMultiply ,
72+ classOf [Divide ] -> CometDivide ,
73+ classOf [IntegralDivide ] -> CometIntegralDivide ,
74+ classOf [Remainder ] -> CometRemainder ,
7075 classOf [ArrayAppend ] -> CometArrayAppend ,
7176 classOf [ArrayContains ] -> CometArrayContains ,
7277 classOf [ArrayDistinct ] -> CometArrayDistinct ,
@@ -630,141 +635,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
630635 case c @ Cast (child, dt, timeZoneId, _) =>
631636 handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c))
632637
633- case add @ Add (left, right, _) if supportedDataType(left.dataType) =>
634- createMathExpression(
635- expr,
636- left,
637- right,
638- inputs,
639- binding,
640- add.dataType,
641- add.evalMode == EvalMode .ANSI ,
642- (builder, mathExpr) => builder.setAdd(mathExpr))
643-
644- case add @ Add (left, _, _) if ! supportedDataType(left.dataType) =>
645- withInfo(add, s " Unsupported datatype ${left.dataType}" )
646- None
647-
648- case sub @ Subtract (left, right, _) if supportedDataType(left.dataType) =>
649- createMathExpression(
650- expr,
651- left,
652- right,
653- inputs,
654- binding,
655- sub.dataType,
656- sub.evalMode == EvalMode .ANSI ,
657- (builder, mathExpr) => builder.setSubtract(mathExpr))
658-
659- case sub @ Subtract (left, _, _) if ! supportedDataType(left.dataType) =>
660- withInfo(sub, s " Unsupported datatype ${left.dataType}" )
661- None
662-
663- case mul @ Multiply (left, right, _) if supportedDataType(left.dataType) =>
664- createMathExpression(
665- expr,
666- left,
667- right,
668- inputs,
669- binding,
670- mul.dataType,
671- mul.evalMode == EvalMode .ANSI ,
672- (builder, mathExpr) => builder.setMultiply(mathExpr))
673-
674- case mul @ Multiply (left, _, _) =>
675- if (! supportedDataType(left.dataType)) {
676- withInfo(mul, s " Unsupported datatype ${left.dataType}" )
677- }
678- None
679-
680- case div @ Divide (left, right, _) if supportedDataType(left.dataType) =>
681- // Datafusion now throws an exception for dividing by zero
682- // See https://github.com/apache/arrow-datafusion/pull/6792
683- // For now, use NullIf to swap zeros with nulls.
684- val rightExpr = nullIfWhenPrimitive(right)
685-
686- createMathExpression(
687- expr,
688- left,
689- rightExpr,
690- inputs,
691- binding,
692- div.dataType,
693- div.evalMode == EvalMode .ANSI ,
694- (builder, mathExpr) => builder.setDivide(mathExpr))
695-
696- case div @ Divide (left, _, _) =>
697- if (! supportedDataType(left.dataType)) {
698- withInfo(div, s " Unsupported datatype ${left.dataType}" )
699- }
700- None
701-
702- case div @ IntegralDivide (left, right, _) if supportedDataType(left.dataType) =>
703- val rightExpr = nullIfWhenPrimitive(right)
704-
705- val dataType = (left.dataType, right.dataType) match {
706- case (l : DecimalType , r : DecimalType ) =>
707- // copy from IntegralDivide.resultDecimalType
708- val intDig = l.precision - l.scale + r.scale
709- DecimalType (min(if (intDig == 0 ) 1 else intDig, DecimalType .MAX_PRECISION ), 0 )
710- case _ => left.dataType
711- }
712-
713- val divideExpr = createMathExpression(
714- expr,
715- left,
716- rightExpr,
717- inputs,
718- binding,
719- dataType,
720- div.evalMode == EvalMode .ANSI ,
721- (builder, mathExpr) => builder.setIntegralDivide(mathExpr))
722-
723- if (divideExpr.isDefined) {
724- val childExpr = if (dataType.isInstanceOf [DecimalType ]) {
725- // check overflow for decimal type
726- val builder = ExprOuterClass .CheckOverflow .newBuilder()
727- builder.setChild(divideExpr.get)
728- builder.setFailOnError(div.evalMode == EvalMode .ANSI )
729- builder.setDatatype(serializeDataType(dataType).get)
730- Some (
731- ExprOuterClass .Expr
732- .newBuilder()
733- .setCheckOverflow(builder)
734- .build())
735- } else {
736- divideExpr
737- }
738-
739- // cast result to long
740- castToProto(expr, None , LongType , childExpr.get, CometEvalMode .LEGACY )
741- } else {
742- None
743- }
744-
745- case div @ IntegralDivide (left, _, _) =>
746- if (! supportedDataType(left.dataType)) {
747- withInfo(div, s " Unsupported datatype ${left.dataType}" )
748- }
749- None
750-
751- case rem @ Remainder (left, right, _) if supportedDataType(left.dataType) =>
752- createMathExpression(
753- expr,
754- left,
755- right,
756- inputs,
757- binding,
758- rem.dataType,
759- rem.evalMode == EvalMode .ANSI ,
760- (builder, mathExpr) => builder.setRemainder(mathExpr))
761-
762- case rem @ Remainder (left, _, _) =>
763- if (! supportedDataType(left.dataType)) {
764- withInfo(rem, s " Unsupported datatype ${left.dataType}" )
765- }
766- None
767-
768638 case EqualTo (left, right) =>
769639 createBinaryExpr(
770640 expr,
@@ -1962,42 +1832,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
19621832 }
19631833 }
19641834
1965- private def createMathExpression (
1966- expr : Expression ,
1967- left : Expression ,
1968- right : Expression ,
1969- inputs : Seq [Attribute ],
1970- binding : Boolean ,
1971- dataType : DataType ,
1972- failOnError : Boolean ,
1973- f : (ExprOuterClass .Expr .Builder , ExprOuterClass .MathExpr ) => ExprOuterClass .Expr .Builder )
1974- : Option [ExprOuterClass .Expr ] = {
1975- val leftExpr = exprToProtoInternal(left, inputs, binding)
1976- val rightExpr = exprToProtoInternal(right, inputs, binding)
1977-
1978- if (leftExpr.isDefined && rightExpr.isDefined) {
1979- // create the generic MathExpr message
1980- val builder = ExprOuterClass .MathExpr .newBuilder()
1981- builder.setLeft(leftExpr.get)
1982- builder.setRight(rightExpr.get)
1983- builder.setFailOnError(failOnError)
1984- serializeDataType(dataType).foreach { t =>
1985- builder.setReturnType(t)
1986- }
1987- val inner = builder.build()
1988- // call the user-supplied function to wrap MathExpr in a top-level Expr
1989- // such as Expr.Add or Expr.Divide
1990- Some (
1991- f(
1992- ExprOuterClass .Expr
1993- .newBuilder(),
1994- inner).build())
1995- } else {
1996- withInfo(expr, left, right)
1997- None
1998- }
1999- }
2000-
20011835 def in (
20021836 expr : Expression ,
20031837 value : Expression ,
@@ -2053,25 +1887,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
20531887 Some (ExprOuterClass .Expr .newBuilder().setScalarFunc(builder).build())
20541888 }
20551889
2056- private def isPrimitive (expression : Expression ): Boolean = expression.dataType match {
2057- case _ : ByteType | _ : ShortType | _ : IntegerType | _ : LongType | _ : FloatType |
2058- _ : DoubleType | _ : TimestampType | _ : DateType | _ : BooleanType | _ : DecimalType =>
2059- true
2060- case _ => false
2061- }
2062-
2063- private def nullIfWhenPrimitive (expression : Expression ): Expression =
2064- if (isPrimitive(expression)) {
2065- val zero = Literal .default(expression.dataType)
2066- expression match {
2067- case _ : Literal if expression != zero => expression
2068- case _ =>
2069- If (EqualTo (expression, zero), Literal .create(null , expression.dataType), expression)
2070- }
2071- } else {
2072- expression
2073- }
2074-
20751890 private def nullIfNegative (expression : Expression ): Expression = {
20761891 val zero = Literal .default(expression.dataType)
20771892 If (LessThanOrEqual (expression, zero), Literal .create(null , expression.dataType), expression)
0 commit comments