Skip to content

Commit 0dc71c0

Browse files
authored
chore: Refactor serde for more array and struct expressions (#2257)
1 parent 7e0ff1a commit 0dc71c0

File tree

3 files changed: +242 additions, -165 deletions

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 6 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
9595
classOf[ArraysOverlap] -> CometArraysOverlap,
9696
classOf[ArrayUnion] -> CometArrayUnion,
9797
classOf[CreateArray] -> CometCreateArray,
98+
classOf[GetArrayItem] -> CometGetArrayItem,
99+
classOf[ElementAt] -> CometElementAt,
98100
classOf[Ascii] -> CometScalarFunction("ascii"),
99101
classOf[ConcatWs] -> CometScalarFunction("concat_ws"),
100102
classOf[Chr] -> CometScalarFunction("char"),
@@ -170,6 +172,10 @@ object QueryPlanSerde extends Logging with CometExprShim {
170172
classOf[DateSub] -> CometDateSub,
171173
classOf[TruncDate] -> CometTruncDate,
172174
classOf[TruncTimestamp] -> CometTruncTimestamp,
175+
classOf[CreateNamedStruct] -> CometCreateNamedStruct,
176+
classOf[GetStructField] -> CometGetStructField,
177+
classOf[GetArrayStructFields] -> CometGetArrayStructFields,
178+
classOf[StructsToJson] -> CometStructsToJson,
173179
classOf[Flatten] -> CometFlatten,
174180
classOf[Atan2] -> CometAtan2,
175181
classOf[Ceil] -> CometCeil,
@@ -922,66 +928,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
922928
None
923929
}
924930

925-
case StructsToJson(options, child, timezoneId) =>
926-
if (options.nonEmpty) {
927-
withInfo(expr, "StructsToJson with options is not supported")
928-
None
929-
} else {
930-
931-
def isSupportedType(dt: DataType): Boolean = {
932-
dt match {
933-
case StructType(fields) =>
934-
fields.forall(f => isSupportedType(f.dataType))
935-
case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
936-
DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
937-
DataTypes.DoubleType | DataTypes.StringType =>
938-
true
939-
case DataTypes.DateType | DataTypes.TimestampType =>
940-
// TODO implement these types with tests for formatting options and timezone
941-
false
942-
case _: MapType | _: ArrayType =>
943-
// Spark supports map and array in StructsToJson but this is not yet
944-
// implemented in Comet
945-
false
946-
case _ => false
947-
}
948-
}
949-
950-
val isSupported = child.dataType match {
951-
case s: StructType =>
952-
s.fields.forall(f => isSupportedType(f.dataType))
953-
case _: MapType | _: ArrayType =>
954-
// Spark supports map and array in StructsToJson but this is not yet
955-
// implemented in Comet
956-
false
957-
case _ =>
958-
false
959-
}
960-
961-
if (isSupported) {
962-
exprToProtoInternal(child, inputs, binding) match {
963-
case Some(p) =>
964-
val toJson = ExprOuterClass.ToJson
965-
.newBuilder()
966-
.setChild(p)
967-
.setTimezone(timezoneId.getOrElse("UTC"))
968-
.setIgnoreNullFields(true)
969-
.build()
970-
Some(
971-
ExprOuterClass.Expr
972-
.newBuilder()
973-
.setToJson(toJson)
974-
.build())
975-
case _ =>
976-
withInfo(expr, child)
977-
None
978-
}
979-
} else {
980-
withInfo(expr, "Unsupported data type", child)
981-
None
982-
}
983-
}
984-
985931
case SortOrder(child, direction, nullOrdering, _) =>
986932
val childExpr = exprToProtoInternal(child, inputs, binding)
987933

@@ -1336,110 +1282,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
13361282
withInfo(expr, bloomFilter, value)
13371283
None
13381284
}
1339-
1340-
case struct @ CreateNamedStruct(_) =>
1341-
if (struct.names.length != struct.names.distinct.length) {
1342-
withInfo(expr, "CreateNamedStruct with duplicate field names are not supported")
1343-
return None
1344-
}
1345-
1346-
val valExprs = struct.valExprs.map(exprToProtoInternal(_, inputs, binding))
1347-
1348-
if (valExprs.forall(_.isDefined)) {
1349-
val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
1350-
structBuilder.addAllValues(valExprs.map(_.get).asJava)
1351-
structBuilder.addAllNames(struct.names.map(_.toString).asJava)
1352-
1353-
Some(
1354-
ExprOuterClass.Expr
1355-
.newBuilder()
1356-
.setCreateNamedStruct(structBuilder)
1357-
.build())
1358-
} else {
1359-
withInfo(expr, "unsupported arguments for CreateNamedStruct", struct.valExprs: _*)
1360-
None
1361-
}
1362-
1363-
case GetStructField(child, ordinal, _) =>
1364-
exprToProtoInternal(child, inputs, binding).map { childExpr =>
1365-
val getStructFieldBuilder = ExprOuterClass.GetStructField
1366-
.newBuilder()
1367-
.setChild(childExpr)
1368-
.setOrdinal(ordinal)
1369-
1370-
ExprOuterClass.Expr
1371-
.newBuilder()
1372-
.setGetStructField(getStructFieldBuilder)
1373-
.build()
1374-
}
1375-
1376-
case GetArrayItem(child, ordinal, failOnError) =>
1377-
val childExpr = exprToProtoInternal(child, inputs, binding)
1378-
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
1379-
1380-
if (childExpr.isDefined && ordinalExpr.isDefined) {
1381-
val listExtractBuilder = ExprOuterClass.ListExtract
1382-
.newBuilder()
1383-
.setChild(childExpr.get)
1384-
.setOrdinal(ordinalExpr.get)
1385-
.setOneBased(false)
1386-
.setFailOnError(failOnError)
1387-
1388-
Some(
1389-
ExprOuterClass.Expr
1390-
.newBuilder()
1391-
.setListExtract(listExtractBuilder)
1392-
.build())
1393-
} else {
1394-
withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal)
1395-
None
1396-
}
1397-
1398-
case ElementAt(child, ordinal, defaultValue, failOnError)
1399-
if child.dataType.isInstanceOf[ArrayType] =>
1400-
val childExpr = exprToProtoInternal(child, inputs, binding)
1401-
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
1402-
val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))
1403-
1404-
if (childExpr.isDefined && ordinalExpr.isDefined &&
1405-
defaultExpr.isDefined == defaultValue.isDefined) {
1406-
val arrayExtractBuilder = ExprOuterClass.ListExtract
1407-
.newBuilder()
1408-
.setChild(childExpr.get)
1409-
.setOrdinal(ordinalExpr.get)
1410-
.setOneBased(true)
1411-
.setFailOnError(failOnError)
1412-
1413-
defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
1414-
1415-
Some(
1416-
ExprOuterClass.Expr
1417-
.newBuilder()
1418-
.setListExtract(arrayExtractBuilder)
1419-
.build())
1420-
} else {
1421-
withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
1422-
None
1423-
}
1424-
1425-
case GetArrayStructFields(child, _, ordinal, _, _) =>
1426-
val childExpr = exprToProtoInternal(child, inputs, binding)
1427-
1428-
if (childExpr.isDefined) {
1429-
val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
1430-
.newBuilder()
1431-
.setChild(childExpr.get)
1432-
.setOrdinal(ordinal)
1433-
1434-
Some(
1435-
ExprOuterClass.Expr
1436-
.newBuilder()
1437-
.setGetArrayStructFields(arrayStructFieldsBuilder)
1438-
.build())
1439-
} else {
1440-
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
1441-
None
1442-
}
14431285
case af @ ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
14441286
convert(af, CometArrayCompact)
14451287
case expr =>

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import scala.annotation.tailrec
2323

24-
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, Expression, Flatten, Literal}
24+
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, Literal}
2525
import org.apache.spark.sql.internal.SQLConf
2626
import org.apache.spark.sql.types._
2727

@@ -404,6 +404,72 @@ object CometCreateArray extends CometExpressionSerde[CreateArray] {
404404
}
405405
}
406406

407+
/**
 * Serde for Spark's `GetArrayItem` expression.
 *
 * Emits a `ListExtract` protobuf expression with `oneBased = false` and the expression's own
 * `failOnError` flag.
 */
object CometGetArrayItem extends CometExpressionSerde[GetArrayItem] {
  override def convert(
      expr: GetArrayItem,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // Convert both operands up front, before inspecting either result.
    val arrayProto = exprToProtoInternal(expr.child, inputs, binding)
    val indexProto = exprToProtoInternal(expr.ordinal, inputs, binding)

    (arrayProto, indexProto) match {
      case (Some(array), Some(index)) =>
        val extract = ExprOuterClass.ListExtract
          .newBuilder()
          .setChild(array)
          .setOrdinal(index)
          .setOneBased(false)
          .setFailOnError(expr.failOnError)

        Some(
          ExprOuterClass.Expr
            .newBuilder()
            .setListExtract(extract)
            .build())

      case _ =>
        // At least one operand could not be serialized.
        withInfo(expr, "unsupported arguments for GetArrayItem", expr.child, expr.ordinal)
        None
    }
  }
}
434+
435+
/**
 * Serde for Spark's `ElementAt` expression.
 *
 * Only array inputs are handled; any other input type is reported via `withInfo` and yields
 * `None`. For arrays, emits a `ListExtract` protobuf expression with `oneBased = true`, the
 * expression's `failOnError` flag and, when present, the out-of-bounds default value.
 */
object CometElementAt extends CometExpressionSerde[ElementAt] {

  override def convert(
      expr: ElementAt,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // Check the input type before serializing any child, so that non-array inputs are
    // rejected without doing conversion work whose result would be discarded.
    expr.left.dataType match {
      case _: ArrayType =>
        val childExpr = exprToProtoInternal(expr.left, inputs, binding)
        val ordinalExpr = exprToProtoInternal(expr.right, inputs, binding)
        val defaultExpr =
          expr.defaultValueOutOfBound.flatMap(exprToProtoInternal(_, inputs, binding))

        // False only when a default value was supplied but could not be serialized.
        val defaultConverted = defaultExpr.isDefined == expr.defaultValueOutOfBound.isDefined

        (childExpr, ordinalExpr) match {
          case (Some(child), Some(ordinal)) if defaultConverted =>
            val arrayExtractBuilder = ExprOuterClass.ListExtract
              .newBuilder()
              .setChild(child)
              .setOrdinal(ordinal)
              .setOneBased(true)
              .setFailOnError(expr.failOnError)

            defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))

            Some(
              ExprOuterClass.Expr
                .newBuilder()
                .setListExtract(arrayExtractBuilder)
                .build())

          case _ =>
            // Pass the default-value expression as well: it is one of the arguments whose
            // failed conversion can lead here.
            val arguments = Seq(expr.left, expr.right) ++ expr.defaultValueOutOfBound.toSeq
            withInfo(expr, "unsupported arguments for ElementAt", arguments: _*)
            None
        }

      case _ =>
        withInfo(expr, "Input is not an array")
        None
    }
  }
}
472+
407473
object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase {
408474

409475
override def convert(

0 commit comments

Comments
 (0)