diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 3e0e837c9c..6a8ee0d9d5 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero -import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution @@ -229,7 +228,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Literal] -> CometLiteral, classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId, classOf[SparkPartitionID] -> CometSparkPartitionId, - classOf[SortOrder] -> CometSortOrder) + classOf[SortOrder] -> CometSortOrder, + classOf[StaticInvoke] -> CometStaticInvoke) /** * Mapping of Spark expression class to Comet expression handler. @@ -697,30 +697,6 @@ object QueryPlanSerde extends Logging with CometExprShim { // `PromotePrecision` is just a wrapper, don't need to serialize it. exprToProtoInternal(child, inputs, binding) - // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for - // char types. 
- // See https://github.com/apache/spark/pull/38151 - case s: StaticInvoke - // classOf gets ther runtime class of T, which lets us compare directly - // Otherwise isInstanceOf[Class[T]] will always evaluate to true for Class[_] - if s.staticObject == classOf[CharVarcharCodegenUtils] && - s.dataType.isInstanceOf[StringType] && - s.functionName == "readSidePadding" && - s.arguments.size == 2 && - s.propagateNull && - !s.returnNullable && - s.isDeterministic => - val argsExpr = Seq( - exprToProtoInternal(Cast(s.arguments(0), StringType), inputs, binding), - exprToProtoInternal(s.arguments(1), inputs, binding)) - - if (argsExpr.forall(_.isDefined)) { - scalarFunctionExprToProto("read_side_padding", argsExpr: _*) - } else { - withInfo(expr, s.arguments: _*) - None - } - case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) => val dataType = serializeDataType(expr.dataType) if (dataType.isEmpty) { diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala new file mode 100644 index 0000000000..0737644ab9 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+
+object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
+
+  /**
+   * Serde handlers keyed by the (function name, declaring class) of the static call.
+   *
+   * With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
+   * char types. See https://github.com/apache/spark/pull/38151
+   */
+  private val handlersByTarget: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] = {
+    val readSidePadding = ("readSidePadding", classOf[CharVarcharCodegenUtils])
+    Map(readSidePadding -> CometScalarFunction("read_side_padding"))
+  }
+
+  /** Delegates to the registered handler, or reports the call as unsupported via withInfo. */
+  override def convert(
+      expr: StaticInvoke,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val target = (expr.functionName, expr.staticObject)
+    handlersByTarget.get(target).fold[Option[ExprOuterClass.Expr]] {
+      // No handler is registered for this (function, class) pair.
+      val reason = s"Static invoke expression: ${expr.functionName} is not supported"
+      withInfo(expr, reason, expr.children: _*)
+      None
+    }(_.convert(expr, inputs, binding))
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
index 3d08c01a7d..2479a41a37 100644
--- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -139,7 +139,9 @@ class CometStringExpressionSuite extends CometTestBase {
       } else {
         // Comet will fall back to Spark because the plan contains a staticinvoke instruction
         // which is not supported
-        checkSparkAnswerAndFallbackReason(sql, "staticinvoke is not supported")
+        checkSparkAnswerAndFallbackReason(
+          sql,
+          s"Static invoke expression: $expr is not supported")
       }
     }
   }