Skip to content

Commit bcf7151

Browse files
committed
feat: support array_repeat
1 parent e823163 commit bcf7151

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1979,6 +1979,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
19791979
case _: ArrayIntersect => convert(CometArrayIntersect)
19801980
case _: ArrayJoin => convert(CometArrayJoin)
19811981
case _: ArraysOverlap => convert(CometArraysOverlap)
1982+
case _: ArrayRepeat => convert(CometArrayRepeat)
19821983
case _ @ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
19831984
convert(CometArrayCompact)
19841985
case _: ArrayExcept =>
@@ -3068,7 +3069,7 @@ trait CometAggregateExpressionSerde {
30683069
* Convert a Spark expression into a protocol buffer representation that can be passed into
30693070
* native code.
30703071
*
3071-
* @param expr
3072+
* @param aggExpr
30723073
* The aggregate expression.
30733074
* @param expr
30743075
* The aggregate function.

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,20 @@ object CometArraysOverlap extends CometExpressionSerde with IncompatExpr {
179179
}
180180
}
181181

182+
object CometArrayRepeat extends CometExpressionSerde with IncompatExpr {
183+
override def convert(
184+
expr: Expression,
185+
inputs: Seq[Attribute],
186+
binding: Boolean): Option[ExprOuterClass.Expr] = {
187+
val leftArrayExprProto = exprToProto(expr.children.head, inputs, binding)
188+
val rightArrayExprProto = exprToProto(expr.children(1), inputs, binding)
189+
190+
val arraysRepeatScalarExpr =
191+
scalarExprToProto("array_repeat", leftArrayExprProto, rightArrayExprProto)
192+
optExprWithInfo(arraysRepeatScalarExpr, expr, expr.children: _*)
193+
}
194+
}
195+
182196
object CometArrayCompact extends CometExpressionSerde with IncompatExpr {
183197
override def convert(
184198
expr: Expression,

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,4 +410,26 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
410410
}
411411
}
412412

413+
test("array_repeat") {
414+
withSQLConf(
415+
CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true",
416+
CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {
417+
Seq(true, false).foreach { dictionaryEnabled =>
418+
withTempDir { dir =>
419+
val path = new Path(dir.toURI.toString, "test.parquet")
420+
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
421+
spark.read.parquet(path.toString).createOrReplaceTempView("t1")
422+
spark.sql("select * from t1").printSchema()
423+
424+
checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, null) from t1"))
425+
// checkSparkAnswerAndOperator(
426+
// sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null"))
427+
// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, 2) from t1 where _3 is null"))
428+
// checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, _3) from t1 where _3 is null"))
429+
// checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1"))
430+
// checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1"))
431+
}
432+
}
433+
}
434+
}
413435
}

0 commit comments

Comments
 (0)