Commit bffe60b

feat: implement array flatten function (#2039)

1 parent 178ab5d
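For context: Spark's `flatten` function removes one level of nesting from an array of arrays, and this commit maps the corresponding `Flatten` expression to Comet's native scalar-function path instead of falling back to Spark. A quick illustration of the SQL semantics (plain Spark behavior, assuming an active SparkSession named `spark`):

// Spark SQL semantics of flatten, independent of Comet:
spark.sql("SELECT flatten(array(array(1, 2), array(3, 4)))").show()
// result: [1, 2, 3, 4] -- one level of array nesting removed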

3 files changed: 115 additions, 22 deletions

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala (2 additions, 1 deletion)

@@ -169,7 +169,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[DateAdd] -> CometDateAdd,
     classOf[DateSub] -> CometDateSub,
     classOf[TruncDate] -> CometTruncDate,
-    classOf[TruncTimestamp] -> CometTruncTimestamp)
+    classOf[TruncTimestamp] -> CometTruncTimestamp,
+    classOf[Flatten] -> CometFlatten)
 
   /**
    * Mapping of Spark aggregate expression class to Comet expression handler.
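The map edited above associates each Spark expression class with a Comet serde handler; adding `classOf[Flatten] -> CometFlatten` is what makes `flatten` eligible for native conversion. A minimal, self-contained sketch of that class-keyed dispatch pattern (simplified, hypothetical names; not Comet's actual API):

// Illustrative sketch only: class-keyed serde dispatch, as in QueryPlanSerde's map.
trait ExprSerde { def convert(e: Any): Option[String] }

case class Flatten(child: String) // stand-in for Spark's Flatten expression

object FlattenSerde extends ExprSerde {
  def convert(e: Any): Option[String] = Some("scalar_fn:flatten")
}

object Dispatch extends App {
  // One entry per supported expression class, as the diff above adds.
  val handlers: Map[Class[_], ExprSerde] = Map(classOf[Flatten] -> FlattenSerde)

  val expr = Flatten("a")
  // Unregistered expression classes yield None, i.e. fall back to Spark.
  println(handlers.get(expr.getClass).flatMap(_.convert(expr))) // Some(scalar_fn:flatten)
}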

spark/src/main/scala/org/apache/comet/serde/arrays.scala (42 additions, 19 deletions)

@@ -21,32 +21,18 @@ package org.apache.comet.serde
 
 import scala.annotation.tailrec
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, Expression, Flatten, Literal}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde._
 import org.apache.comet.shims.CometExprShim
 
-object CometArrayRemove extends CometExpressionSerde[ArrayRemove] with CometExprShim {
-
-  /** Exposed for unit testing */
-  @tailrec
-  def isTypeSupported(dt: DataType): Boolean = {
-    import DataTypes._
-    dt match {
-      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
-          _: DecimalType | DateType | TimestampType | TimestampNTZType | StringType |
-          BinaryType =>
-        true
-      case ArrayType(elementType, _) => isTypeSupported(elementType)
-      case _: StructType =>
-        // https://github.com/apache/datafusion-comet/issues/1307
-        false
-      case _ => false
-    }
-  }
+object CometArrayRemove
+    extends CometExpressionSerde[ArrayRemove]
+    with CometExprShim
+    with ArraysBase {
 
   override def convert(
       expr: ArrayRemove,
@@ -417,3 +403,40 @@ object CometCreateArray extends CometExpressionSerde[CreateArray] {
     }
   }
 }
+
+object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase {
+
+  override def convert(
+      expr: Flatten,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val inputTypes = expr.children.map(_.dataType).toSet
+    for (dt <- inputTypes) {
+      if (!isTypeSupported(dt)) {
+        withInfo(expr, s"data type not supported: $dt")
+        return None
+      }
+    }
+    val flattenExprProto = exprToProto(expr.child, inputs, binding)
+    val flattenScalarExpr = scalarFunctionExprToProto("flatten", flattenExprProto)
+    optExprWithInfo(flattenScalarExpr, expr, expr.children: _*)
+  }
+}
+
+trait ArraysBase {
+
+  def isTypeSupported(dt: DataType): Boolean = {
+    import DataTypes._
+    dt match {
+      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
+          _: DecimalType | DateType | TimestampType | TimestampNTZType | StringType =>
        true
+      case BinaryType => false
+      case ArrayType(elementType, _) => isTypeSupported(elementType)
+      case _: StructType =>
+        // https://github.com/apache/datafusion-comet/issues/1307
+        false
+      case _ => false
+    }
+  }
+}
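Factoring `isTypeSupported` out of `CometArrayRemove` into the shared `ArraysBase` trait lets `CometFlatten` reuse the same recursive check; note the one behavioral change visible in the diff, `BinaryType` now returns false where the old `CometArrayRemove` check accepted it. A small usage sketch of the trait (assuming the code above and spark-catalyst are on the classpath):

import org.apache.spark.sql.types._

// Exercising the shared type check from ArraysBase above.
object TypeCheckDemo extends App with ArraysBase {
  println(isTypeSupported(ArrayType(ArrayType(IntegerType))))   // true: recurses into nested arrays
  println(isTypeSupported(ArrayType(BinaryType)))               // false: binary is now rejected
  println(isTypeSupported(new StructType().add("x", LongType))) // false: structs, see issue #1307
}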

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala (71 additions, 2 deletions)

@@ -29,7 +29,7 @@ import org.apache.spark.sql.functions._
 
 import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
 import org.apache.comet.DataTypeSupport.isComplexType
-import org.apache.comet.serde.CometArrayExcept
+import org.apache.comet.serde.{CometArrayExcept, CometArrayRemove, CometFlatten}
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
@@ -71,7 +71,11 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     val table = spark.read.parquet(filename)
     table.createOrReplaceTempView("t1")
     // test with array of each column
-    for (fieldName <- table.schema.fieldNames) {
+    val fieldNames =
+      table.schema.fields
+        .filter(field => CometArrayRemove.isTypeSupported(field.dataType))
+        .map(_.name)
+    for (fieldName <- fieldNames) {
       sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
         .createOrReplaceTempView("t2")
       val df = sql("SELECT array_remove(a, b) FROM t2")
@@ -623,4 +627,69 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("flatten - test all types (native Parquet reader)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      table.createOrReplaceTempView("t1")
+      val fieldNames =
+        table.schema.fields
+          .filter(field => CometFlatten.isTypeSupported(field.dataType))
+          .map(_.name)
+      for (fieldName <- fieldNames) {
+        sql(s"SELECT array(array($fieldName, $fieldName), array($fieldName)) as a FROM t1")
+          .createOrReplaceTempView("t2")
+        checkSparkAnswerAndOperator(sql("SELECT flatten(a) FROM t2"))
+      }
+    }
+  }
+
+  test("flatten - test all types (convert from Parquet)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val options = DataGenOptions(
+          allowNull = true,
+          generateNegativeZero = true,
+          generateArray = true,
+          generateStruct = true,
+          generateMap = false)
+        ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
+      }
+      withSQLConf(
+        CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        val fieldNames =
+          table.schema.fields
+            .filter(field => CometFlatten.isTypeSupported(field.dataType))
+            .map(_.name)
+        for (fieldName <- fieldNames) {
+          sql(s"SELECT array(array($fieldName, $fieldName), array($fieldName)) as a FROM t1")
+            .createOrReplaceTempView("t2")
+          checkSparkAnswer(sql("SELECT flatten(a) FROM t2"))
+        }
+      }
+    }
+  }
 }
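The two new tests differ in their assertion helper: the native-reader test uses `checkSparkAnswerAndOperator`, which also asserts that the plan executed with Comet operators, while the convert-from-Parquet test uses `checkSparkAnswer`, which only compares results against vanilla Spark and tolerates fallbacks. A hedged spark-shell sketch of the same query shape (session setup and the exact plan node names are assumptions):

// Assumes a SparkSession `spark` configured with the Comet extensions.
val df = spark.sql(
  "SELECT flatten(array(array(id, id + 1), array(id + 2))) FROM range(3)")
df.show()  // expected rows: [0, 1, 2], [1, 2, 3], [2, 3, 4]
// If the expression was converted natively, the physical plan should contain
// a Comet projection node (an assumption based on typical Comet plans):
df.explain()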
