Skip to content

Commit 1553caa

Browse files
authored
chore: Refactor Literal serde (#2377)
1 parent 3b29cb9 commit 1553caa

File tree

2 files changed

+196
-150
lines changed

2 files changed

+196
-150
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 5 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
2828
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
2929
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
3030
import org.apache.spark.sql.catalyst.plans._
31-
import org.apache.spark.sql.catalyst.util.{CharVarcharCodegenUtils, GenericArrayData}
31+
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
3232
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
3333
import org.apache.spark.sql.comet._
3434
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
@@ -43,21 +43,17 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, Sh
4343
import org.apache.spark.sql.execution.window.WindowExec
4444
import org.apache.spark.sql.internal.SQLConf
4545
import org.apache.spark.sql.types._
46-
import org.apache.spark.unsafe.types.UTF8String
47-
48-
import com.google.protobuf.ByteString
4946

5047
import org.apache.comet.{CometConf, ConfigEntry}
5148
import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo}
52-
import org.apache.comet.DataTypeSupport.isComplexType
5349
import org.apache.comet.expressions._
5450
import org.apache.comet.objectstore.NativeConfig
5551
import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc}
5652
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
5753
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
5854
import org.apache.comet.serde.Types.{DataType => ProtoDataType}
5955
import org.apache.comet.serde.Types.DataType._
60-
import org.apache.comet.serde.Types.ListLiteral
56+
import org.apache.comet.serde.literals.CometLiteral
6157
import org.apache.comet.shims.CometExprShim
6258

6359
/**
@@ -213,7 +209,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
213209
classOf[Cast] -> CometCast)
214210

215211
private val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
216-
// TODO Literal
217212
// TODO SortOrder (?)
218213
// TODO PromotePrecision
219214
// TODO CheckOverflow
@@ -225,9 +220,10 @@ object QueryPlanSerde extends Logging with CometExprShim {
225220
// TODO RegExpReplace
226221
classOf[Alias] -> CometAlias,
227222
classOf[AttributeReference] -> CometAttributeReference,
228-
classOf[SparkPartitionID] -> CometSparkPartitionId,
223+
classOf[Coalesce] -> CometCoalesce,
224+
classOf[Literal] -> CometLiteral,
229225
classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId,
230-
classOf[Coalesce] -> CometCoalesce)
226+
classOf[SparkPartitionID] -> CometSparkPartitionId)
231227

232228
/**
233229
* Mapping of Spark expression class to Comet expression handler.
@@ -677,147 +673,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
677673
val cast = Cast(child, expr.dataType, Some(timeZoneId), EvalMode.TRY)
678674
convert(cast, CometCast)
679675

680-
case Literal(value, dataType)
681-
if supportedDataType(
682-
dataType,
683-
allowComplex = value == null ||
684-
// Nested literal support for native reader
685-
// can be tracked https://github.com/apache/datafusion-comet/issues/1937
686-
// now supports only Array of primitive
687-
(Seq(CometConf.SCAN_NATIVE_ICEBERG_COMPAT, CometConf.SCAN_NATIVE_DATAFUSION)
688-
.contains(CometConf.COMET_NATIVE_SCAN_IMPL.get()) && dataType
689-
.isInstanceOf[ArrayType]) && !isComplexType(
690-
dataType.asInstanceOf[ArrayType].elementType)) =>
691-
val exprBuilder = LiteralOuterClass.Literal.newBuilder()
692-
693-
if (value == null) {
694-
exprBuilder.setIsNull(true)
695-
} else {
696-
exprBuilder.setIsNull(false)
697-
dataType match {
698-
case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean])
699-
case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte])
700-
case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short])
701-
case _: IntegerType | _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
702-
case _: LongType | _: TimestampType | _: TimestampNTZType =>
703-
exprBuilder.setLongVal(value.asInstanceOf[Long])
704-
case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float])
705-
case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double])
706-
case _: StringType =>
707-
exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString)
708-
case _: DecimalType =>
709-
// Pass decimal literal as bytes.
710-
val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue
711-
exprBuilder.setDecimalVal(
712-
com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray))
713-
case _: BinaryType =>
714-
val byteStr =
715-
com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
716-
exprBuilder.setBytesVal(byteStr)
717-
case a: ArrayType =>
718-
val listLiteralBuilder = ListLiteral.newBuilder()
719-
val array = value.asInstanceOf[GenericArrayData].array
720-
a.elementType match {
721-
case NullType =>
722-
array.foreach(_ => listLiteralBuilder.addNullMask(true))
723-
case BooleanType =>
724-
array.foreach(v => {
725-
val casted = v.asInstanceOf[java.lang.Boolean]
726-
listLiteralBuilder.addBooleanValues(casted)
727-
listLiteralBuilder.addNullMask(casted != null)
728-
})
729-
case ByteType =>
730-
array.foreach(v => {
731-
val casted = v.asInstanceOf[java.lang.Integer]
732-
listLiteralBuilder.addByteValues(casted)
733-
listLiteralBuilder.addNullMask(casted != null)
734-
})
735-
case ShortType =>
736-
array.foreach(v => {
737-
val casted = v.asInstanceOf[java.lang.Short]
738-
listLiteralBuilder.addShortValues(
739-
if (casted != null) casted.intValue()
740-
else null.asInstanceOf[java.lang.Integer])
741-
listLiteralBuilder.addNullMask(casted != null)
742-
})
743-
case IntegerType | DateType =>
744-
array.foreach(v => {
745-
val casted = v.asInstanceOf[java.lang.Integer]
746-
listLiteralBuilder.addIntValues(casted)
747-
listLiteralBuilder.addNullMask(casted != null)
748-
})
749-
case LongType | TimestampType | TimestampNTZType =>
750-
array.foreach(v => {
751-
val casted = v.asInstanceOf[java.lang.Long]
752-
listLiteralBuilder.addLongValues(casted)
753-
listLiteralBuilder.addNullMask(casted != null)
754-
})
755-
case FloatType =>
756-
array.foreach(v => {
757-
val casted = v.asInstanceOf[java.lang.Float]
758-
listLiteralBuilder.addFloatValues(casted)
759-
listLiteralBuilder.addNullMask(casted != null)
760-
})
761-
case DoubleType =>
762-
array.foreach(v => {
763-
val casted = v.asInstanceOf[java.lang.Double]
764-
listLiteralBuilder.addDoubleValues(casted)
765-
listLiteralBuilder.addNullMask(casted != null)
766-
})
767-
case StringType =>
768-
array.foreach(v => {
769-
val casted = v.asInstanceOf[org.apache.spark.unsafe.types.UTF8String]
770-
listLiteralBuilder.addStringValues(
771-
if (casted != null) casted.toString else "")
772-
listLiteralBuilder.addNullMask(casted != null)
773-
})
774-
case _: DecimalType =>
775-
array
776-
.foreach(v => {
777-
val casted =
778-
v.asInstanceOf[Decimal]
779-
listLiteralBuilder.addDecimalValues(if (casted != null) {
780-
com.google.protobuf.ByteString
781-
.copyFrom(casted.toBigDecimal.underlying.unscaledValue.toByteArray)
782-
} else ByteString.EMPTY)
783-
listLiteralBuilder.addNullMask(casted != null)
784-
})
785-
case _: BinaryType =>
786-
array
787-
.foreach(v => {
788-
val casted =
789-
v.asInstanceOf[Array[Byte]]
790-
listLiteralBuilder.addBytesValues(if (casted != null) {
791-
com.google.protobuf.ByteString.copyFrom(casted)
792-
} else ByteString.EMPTY)
793-
listLiteralBuilder.addNullMask(casted != null)
794-
})
795-
}
796-
exprBuilder.setListVal(listLiteralBuilder.build())
797-
exprBuilder.setDatatype(serializeDataType(dataType).get)
798-
case dt =>
799-
logWarning(s"Unexpected datatype '$dt' for literal value '$value'")
800-
}
801-
}
802-
803-
val dt = serializeDataType(dataType)
804-
805-
if (dt.isDefined) {
806-
exprBuilder.setDatatype(dt.get)
807-
808-
Some(
809-
ExprOuterClass.Expr
810-
.newBuilder()
811-
.setLiteral(exprBuilder)
812-
.build())
813-
} else {
814-
withInfo(expr, s"Unsupported datatype $dataType")
815-
None
816-
}
817-
case Literal(_, dataType) if !supportedDataType(dataType) =>
818-
withInfo(expr, s"Unsupported datatype $dataType")
819-
None
820-
821676
// ToPrettyString is new in Spark 3.5
822677
case _
823678
if expr.getClass.getSimpleName == "ToPrettyString" && expr
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.serde.literals
21+
22+
import org.apache.spark.internal.Logging
23+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
24+
import org.apache.spark.sql.catalyst.util.GenericArrayData
25+
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, TimestampNTZType, TimestampType}
26+
import org.apache.spark.unsafe.types.UTF8String
27+
28+
import com.google.protobuf.ByteString
29+
30+
import org.apache.comet.CometConf
31+
import org.apache.comet.CometSparkSessionExtensions.withInfo
32+
import org.apache.comet.DataTypeSupport.isComplexType
33+
import org.apache.comet.serde.{CometExpressionSerde, Compatible, ExprOuterClass, LiteralOuterClass, SupportLevel, Unsupported}
34+
import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataType}
35+
import org.apache.comet.serde.Types.ListLiteral
36+
37+
object CometLiteral extends CometExpressionSerde[Literal] with Logging {
38+
39+
override def getSupportLevel(expr: Literal): SupportLevel = {
40+
41+
if (supportedDataType(
42+
expr.dataType,
43+
allowComplex = expr.value == null ||
44+
// Nested literal support for native reader
45+
// can be tracked https://github.com/apache/datafusion-comet/issues/1937
46+
// now supports only Array of primitive
47+
(Seq(CometConf.SCAN_NATIVE_ICEBERG_COMPAT, CometConf.SCAN_NATIVE_DATAFUSION)
48+
.contains(CometConf.COMET_NATIVE_SCAN_IMPL.get()) && expr.dataType
49+
.isInstanceOf[ArrayType]) && !isComplexType(
50+
expr.dataType.asInstanceOf[ArrayType].elementType))) {
51+
Compatible(None)
52+
} else {
53+
Unsupported(Some(s"Unsupported data type ${expr.dataType}"))
54+
}
55+
}
56+
57+
override def convert(
58+
expr: Literal,
59+
inputs: Seq[Attribute],
60+
binding: Boolean): Option[ExprOuterClass.Expr] = {
61+
val dataType = expr.dataType
62+
val value = expr.value
63+
64+
val exprBuilder = LiteralOuterClass.Literal.newBuilder()
65+
66+
if (value == null) {
67+
exprBuilder.setIsNull(true)
68+
} else {
69+
exprBuilder.setIsNull(false)
70+
dataType match {
71+
case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean])
72+
case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte])
73+
case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short])
74+
case _: IntegerType | _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
75+
case _: LongType | _: TimestampType | _: TimestampNTZType =>
76+
exprBuilder.setLongVal(value.asInstanceOf[Long])
77+
case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float])
78+
case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double])
79+
case _: StringType =>
80+
exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString)
81+
case _: DecimalType =>
82+
// Pass decimal literal as bytes.
83+
val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue
84+
exprBuilder.setDecimalVal(com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray))
85+
case _: BinaryType =>
86+
val byteStr =
87+
com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
88+
exprBuilder.setBytesVal(byteStr)
89+
case a: ArrayType =>
90+
val listLiteralBuilder = ListLiteral.newBuilder()
91+
val array = value.asInstanceOf[GenericArrayData].array
92+
a.elementType match {
93+
case NullType =>
94+
array.foreach(_ => listLiteralBuilder.addNullMask(true))
95+
case BooleanType =>
96+
array.foreach(v => {
97+
val casted = v.asInstanceOf[java.lang.Boolean]
98+
listLiteralBuilder.addBooleanValues(casted)
99+
listLiteralBuilder.addNullMask(casted != null)
100+
})
101+
case ByteType =>
102+
array.foreach(v => {
103+
val casted = v.asInstanceOf[java.lang.Integer]
104+
listLiteralBuilder.addByteValues(casted)
105+
listLiteralBuilder.addNullMask(casted != null)
106+
})
107+
case ShortType =>
108+
array.foreach(v => {
109+
val casted = v.asInstanceOf[java.lang.Short]
110+
listLiteralBuilder.addShortValues(
111+
if (casted != null) casted.intValue()
112+
else null.asInstanceOf[java.lang.Integer])
113+
listLiteralBuilder.addNullMask(casted != null)
114+
})
115+
case IntegerType | DateType =>
116+
array.foreach(v => {
117+
val casted = v.asInstanceOf[java.lang.Integer]
118+
listLiteralBuilder.addIntValues(casted)
119+
listLiteralBuilder.addNullMask(casted != null)
120+
})
121+
case LongType | TimestampType | TimestampNTZType =>
122+
array.foreach(v => {
123+
val casted = v.asInstanceOf[java.lang.Long]
124+
listLiteralBuilder.addLongValues(casted)
125+
listLiteralBuilder.addNullMask(casted != null)
126+
})
127+
case FloatType =>
128+
array.foreach(v => {
129+
val casted = v.asInstanceOf[java.lang.Float]
130+
listLiteralBuilder.addFloatValues(casted)
131+
listLiteralBuilder.addNullMask(casted != null)
132+
})
133+
case DoubleType =>
134+
array.foreach(v => {
135+
val casted = v.asInstanceOf[java.lang.Double]
136+
listLiteralBuilder.addDoubleValues(casted)
137+
listLiteralBuilder.addNullMask(casted != null)
138+
})
139+
case StringType =>
140+
array.foreach(v => {
141+
val casted = v.asInstanceOf[org.apache.spark.unsafe.types.UTF8String]
142+
listLiteralBuilder.addStringValues(if (casted != null) casted.toString else "")
143+
listLiteralBuilder.addNullMask(casted != null)
144+
})
145+
case _: DecimalType =>
146+
array
147+
.foreach(v => {
148+
val casted =
149+
v.asInstanceOf[Decimal]
150+
listLiteralBuilder.addDecimalValues(if (casted != null) {
151+
com.google.protobuf.ByteString
152+
.copyFrom(casted.toBigDecimal.underlying.unscaledValue.toByteArray)
153+
} else ByteString.EMPTY)
154+
listLiteralBuilder.addNullMask(casted != null)
155+
})
156+
case _: BinaryType =>
157+
array
158+
.foreach(v => {
159+
val casted =
160+
v.asInstanceOf[Array[Byte]]
161+
listLiteralBuilder.addBytesValues(if (casted != null) {
162+
com.google.protobuf.ByteString.copyFrom(casted)
163+
} else ByteString.EMPTY)
164+
listLiteralBuilder.addNullMask(casted != null)
165+
})
166+
}
167+
exprBuilder.setListVal(listLiteralBuilder.build())
168+
exprBuilder.setDatatype(serializeDataType(dataType).get)
169+
case dt =>
170+
withInfo(expr, s"Unexpected datatype '$dt' for literal value '$value'")
171+
return None
172+
}
173+
}
174+
175+
val dt = serializeDataType(dataType)
176+
177+
if (dt.isDefined) {
178+
exprBuilder.setDatatype(dt.get)
179+
180+
Some(
181+
ExprOuterClass.Expr
182+
.newBuilder()
183+
.setLiteral(exprBuilder)
184+
.build())
185+
} else {
186+
withInfo(expr, s"Unsupported datatype $dataType")
187+
None
188+
}
189+
190+
}
191+
}

0 commit comments

Comments
 (0)