Skip to content

Commit 594ac4f

Browse files
mgaido91 authored and cloud-fan committed
[SPARK-24633][SQL] Fix codegen when split is required for arrays_zip
## What changes were proposed in this pull request?

In the function arrays_zip, when splitting the code is required by the high number of arguments, a codegen error can happen. This PR fixes codegen for cases when splitting the code is required.

## How was this patch tested?

Added a unit test.

Author: Marco Gaido <[email protected]>

Closes apache#21621 from mgaido91/SPARK-24633.
1 parent bac50aa commit 594ac4f

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
200200
""".stripMargin
201201
}
202202

203-
val splittedGetValuesAndCardinalities = ctx.splitExpressions(
203+
val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs(
204204
expressions = getValuesAndCardinalities,
205205
funcName = "getValuesAndCardinalities",
206206
returnType = "int",
@@ -210,7 +210,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
210210
|return $biggestCardinality;
211211
""".stripMargin,
212212
foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"),
213-
arguments =
213+
extraArguments =
214214
("ArrayData[]", arrVals) ::
215215
("int", biggestCardinality) :: Nil)
216216

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
556556
checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8)
557557
}
558558

559+
test("SPARK-24633: arrays_zip splits input processing correctly") {
560+
Seq("true", "false").foreach { wholestageCodegenEnabled =>
561+
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) {
562+
val df = spark.range(1)
563+
val exprs = (0 to 5).map(x => array($"id" + lit(x)))
564+
checkAnswer(df.select(arrays_zip(exprs: _*)),
565+
Row(Seq(Row(0, 1, 2, 3, 4, 5))))
566+
}
567+
}
568+
}
569+
559570
test("map size function") {
560571
val df = Seq(
561572
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),

0 commit comments

Comments
 (0)