Skip to content

Commit 235b69d

Browse files
authored
feat: supports array_distinct (#1923)
1 parent c51f977 commit 235b69d

File tree

4 files changed

+78
-33
lines changed

4 files changed

+78
-33
lines changed

docs/source/user-guide/expressions.md

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -186,20 +186,22 @@ The following Spark expressions are currently available. Any known compatibility
186186

187187
## Arrays
188188

189-
| Expression | Notes |
190-
| -------------- | ------------ |
191-
| ArrayAppend | Experimental |
192-
| ArrayExcept | Experimental |
193-
| ArrayCompact | Experimental |
194-
| ArrayContains | Experimental |
195-
| ArrayInsert | Experimental |
196-
| ArrayIntersect | Experimental |
197-
| ArrayJoin | Experimental |
198-
| ArrayRemove | |
199-
| ArrayRepeat | Experimental |
200-
| ArraysOverlap | Experimental |
201-
| ElementAt | Arrays only |
202-
| GetArrayItem | |
189+
| Expression | Notes |
190+
|----------------|----------------------------------------------------------------------------------------------------------------------------------------|
191+
| ArrayAppend | Experimental |
192+
| ArrayCompact | Experimental |
193+
| ArrayContains | Experimental |
194+
| ArrayDistinct  | Experimental: behaves differently than Spark. DataFusion first sorts and then removes duplicates, while Spark preserves the original order. |
195+
| ArrayExcept | Experimental |
196+
| ArrayInsert | Experimental |
197+
| ArrayIntersect | Experimental |
198+
| ArrayJoin | Experimental |
199+
| ArrayMax | Experimental |
200+
| ArrayRemove | |
201+
| ArrayRepeat | Experimental |
202+
| ArraysOverlap | Experimental |
203+
| ElementAt | Arrays only |
204+
| GetArrayItem | |
203205

204206
## Structs
205207

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
6868
private val exprSerdeMap: Map[Class[_], CometExpressionSerde] = Map(
6969
classOf[ArrayAppend] -> CometArrayAppend,
7070
classOf[ArrayContains] -> CometArrayContains,
71+
classOf[ArrayDistinct] -> CometArrayDistinct,
7172
classOf[ArrayExcept] -> CometArrayExcept,
7273
classOf[ArrayInsert] -> CometArrayInsert,
7374
classOf[ArrayIntersect] -> CometArrayIntersect,

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,19 @@ object CometArrayContains extends CometExpressionSerde with IncompatExpr {
150150
}
151151
}
152152

153+
object CometArrayDistinct extends CometExpressionSerde with IncompatExpr {
154+
override def convert(
155+
expr: Expression,
156+
inputs: Seq[Attribute],
157+
binding: Boolean): Option[ExprOuterClass.Expr] = {
158+
val arrayExprProto = exprToProto(expr.children.head, inputs, binding)
159+
160+
val arrayDistinctScalarExpr =
161+
scalarFunctionExprToProto("array_distinct", arrayExprProto)
162+
optExprWithInfo(arrayDistinctScalarExpr, expr)
163+
}
164+
}
165+
153166
object CometArrayIntersect extends CometExpressionSerde with IncompatExpr {
154167
override def convert(
155168
expr: Expression,
@@ -171,9 +184,9 @@ object CometArrayMax extends CometExpressionSerde {
171184
binding: Boolean): Option[ExprOuterClass.Expr] = {
172185
val arrayExprProto = exprToProto(expr.children.head, inputs, binding)
173186

174-
val arrayContainsScalarExpr =
187+
val arrayMaxScalarExpr =
175188
scalarFunctionExprToProto("array_max", arrayExprProto)
176-
optExprWithInfo(arrayContainsScalarExpr, expr)
189+
optExprWithInfo(arrayMaxScalarExpr, expr)
177190
}
178191
}
179192

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -232,24 +232,53 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
232232
}
233233
}
234234

235+
test("array_distinct") {
236+
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
237+
Seq(true, false).foreach { dictionaryEnabled =>
238+
withTempDir { dir =>
239+
val path = new Path(dir.toURI.toString, "test.parquet")
240+
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
241+
spark.read.parquet(path.toString).createOrReplaceTempView("t1")
242+
// The result needs to be in ascending order for checkSparkAnswerAndOperator to pass
243+
// because DataFusion's array_distinct sorts the elements and then removes the duplicates
244+
checkSparkAnswerAndOperator(
245+
spark.sql("SELECT array_distinct(array(_2, _2, _3, _4, _4)) FROM t1"))
246+
checkSparkAnswerAndOperator(
247+
spark.sql("SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
248+
checkSparkAnswerAndOperator(spark.sql(
249+
"SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_2, _2, _4, _4, _5) END)) FROM t1"))
250+
// NULL needs to be the first element for checkSparkAnswerAndOperator to pass because
251+
// DataFusion's array_distinct sorts the elements and then removes the duplicates
252+
checkSparkAnswerAndOperator(
253+
spark.sql(
254+
"SELECT array_distinct(array(CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1"))
255+
checkSparkAnswerAndOperator(spark.sql(
256+
"SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1"))
257+
}
258+
}
259+
}
260+
}
261+
235262
test("array_max") {
236-
withTempDir { dir =>
237-
val path = new Path(dir.toURI.toString, "test.parquet")
238-
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
239-
spark.read.parquet(path.toString).createOrReplaceTempView("t1");
240-
checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array(_2, _3, _4)) FROM t1"))
241-
checkSparkAnswerAndOperator(
242-
spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"));
243-
checkSparkAnswerAndOperator(
244-
spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_2, _4) END)) FROM t1"));
245-
checkSparkAnswerAndOperator(
246-
spark.sql("SELECT array_max(array(CAST(NULL AS INT), CAST(NULL AS INT))) FROM t1"))
247-
checkSparkAnswerAndOperator(
248-
spark.sql("SELECT array_max(array(_2, CAST(NULL AS INT))) FROM t1"))
249-
checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array()) FROM t1"))
250-
checkSparkAnswerAndOperator(
251-
spark.sql(
252-
"SELECT array_max(array(double('-Infinity'), 0.0, double('Infinity'))) FROM t1"))
263+
Seq(true, false).foreach { dictionaryEnabled =>
264+
withTempDir { dir =>
265+
val path = new Path(dir.toURI.toString, "test.parquet")
266+
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
267+
spark.read.parquet(path.toString).createOrReplaceTempView("t1");
268+
checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array(_2, _3, _4)) FROM t1"))
269+
checkSparkAnswerAndOperator(
270+
spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
271+
checkSparkAnswerAndOperator(
272+
spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_2, _4) END)) FROM t1"))
273+
checkSparkAnswerAndOperator(
274+
spark.sql("SELECT array_max(array(CAST(NULL AS INT), CAST(NULL AS INT))) FROM t1"))
275+
checkSparkAnswerAndOperator(
276+
spark.sql("SELECT array_max(array(_2, CAST(NULL AS INT))) FROM t1"))
277+
checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array()) FROM t1"))
278+
checkSparkAnswerAndOperator(
279+
spark.sql(
280+
"SELECT array_max(array(double('-Infinity'), 0.0, double('Infinity'))) FROM t1"))
281+
}
253282
}
254283
}
255284

0 commit comments

Comments
 (0)