Skip to content

Commit 41c6f36

Browse files
kiszk authored and cloud-fan committed
[SPARK-22549][SQL] Fix 64KB JVM bytecode limit problem with concat_ws
## What changes were proposed in this pull request? This PR changes the `concat_ws` code generation to place the generated code for the argument expressions into separate methods when that code could be large. This resolves the case of `concat_ws` with a large number of arguments. ## How was this patch tested? Added new test cases to `StringExpressionsSuite`. Author: Kazuaki Ishizaki <[email protected]> Closes #19777 from kiszk/SPARK-22549.
1 parent c13b60e commit 41c6f36

File tree

2 files changed

+89
-24
lines changed

2 files changed

+89
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 76 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -137,13 +137,34 @@ case class ConcatWs(children: Seq[Expression])
137137
if (children.forall(_.dataType == StringType)) {
138138
// All children are strings. In that case we can construct a fixed size array.
139139
val evals = children.map(_.genCode(ctx))
140-
141-
val inputs = evals.map { eval =>
142-
s"${eval.isNull} ? (UTF8String) null : ${eval.value}"
143-
}.mkString(", ")
144-
145-
ev.copy(evals.map(_.code).mkString("\n") + s"""
146-
UTF8String ${ev.value} = UTF8String.concatWs($inputs);
140+
val separator = evals.head
141+
val strings = evals.tail
142+
val numArgs = strings.length
143+
val args = ctx.freshName("args")
144+
145+
val inputs = strings.zipWithIndex.map { case (eval, index) =>
146+
if (eval.isNull != "true") {
147+
s"""
148+
${eval.code}
149+
if (!${eval.isNull}) {
150+
$args[$index] = ${eval.value};
151+
}
152+
"""
153+
} else {
154+
""
155+
}
156+
}
157+
val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
158+
ctx.splitExpressions(inputs, "valueConcatWs",
159+
("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
160+
} else {
161+
inputs.mkString("\n")
162+
}
163+
ev.copy(s"""
164+
UTF8String[] $args = new UTF8String[$numArgs];
165+
${separator.code}
166+
$codes
167+
UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args);
147168
boolean ${ev.isNull} = ${ev.value} == null;
148169
""")
149170
} else {
@@ -156,32 +177,63 @@ case class ConcatWs(children: Seq[Expression])
156177
child.dataType match {
157178
case StringType =>
158179
("", // we count all the StringType arguments num at once below.
159-
s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};")
180+
if (eval.isNull == "true") {
181+
""
182+
} else {
183+
s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
184+
})
160185
case _: ArrayType =>
161186
val size = ctx.freshName("n")
162-
(s"""
163-
if (!${eval.isNull}) {
164-
$varargNum += ${eval.value}.numElements();
165-
}
166-
""",
167-
s"""
168-
if (!${eval.isNull}) {
169-
final int $size = ${eval.value}.numElements();
170-
for (int j = 0; j < $size; j ++) {
171-
$array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
172-
}
187+
if (eval.isNull == "true") {
188+
("", "")
189+
} else {
190+
(s"""
191+
if (!${eval.isNull}) {
192+
$varargNum += ${eval.value}.numElements();
193+
}
194+
""",
195+
s"""
196+
if (!${eval.isNull}) {
197+
final int $size = ${eval.value}.numElements();
198+
for (int j = 0; j < $size; j ++) {
199+
$array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
200+
}
201+
}
202+
""")
173203
}
174-
""")
175204
}
176205
}.unzip
177206

178-
ev.copy(evals.map(_.code).mkString("\n") +
179-
s"""
207+
val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code))
208+
val varargCounts = ctx.splitExpressions(varargCount, "varargCountsConcatWs",
209+
("InternalRow", ctx.INPUT_ROW) :: Nil,
210+
"int",
211+
{ body =>
212+
s"""
213+
int $varargNum = 0;
214+
$body
215+
return $varargNum;
216+
"""
217+
},
218+
_.mkString(s"$varargNum += ", s";\n$varargNum += ", ";"))
219+
val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs",
220+
("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
221+
"int",
222+
{ body =>
223+
s"""
224+
$body
225+
return $idxInVararg;
226+
"""
227+
},
228+
_.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";"))
229+
ev.copy(
230+
s"""
231+
$codes
180232
int $varargNum = ${children.count(_.dataType == StringType) - 1};
181233
int $idxInVararg = 0;
182-
${varargCount.mkString("\n")}
234+
$varargCounts
183235
UTF8String[] $array = new UTF8String[$varargNum];
184-
${varargBuild.mkString("\n")}
236+
$varargBuilds
185237
UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array);
186238
boolean ${ev.isNull} = ${ev.value} == null;
187239
""")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
8080
// scalastyle:on
8181
}
8282

83+
// Regression test for SPARK-22549: evaluates ConcatWs with a separator followed by N
// arguments, once with string literals and once with single-element string arrays, and
// checks that the result is the inputs joined by "#".
test("SPARK-22549: ConcatWs should not generate codes beyond 64KB") {
84+
// Number of arguments after the separator; per the test name, presumably chosen large
// enough that unsplit generated code would exceed the 64KB limit.
val N = 5000
85+
val sepExpr = Literal.create("#", StringType)
86+
// Case 1: every argument is a string literal "s1" .. "sN".
val strings1 = (1 to N).map(x => s"s$x")
87+
val inputsExpr1 = strings1.map(Literal.create(_, StringType))
88+
checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow)
89+
90+
// Case 2: every argument is a one-element array of strings, Seq("s1") .. Seq("sN").
val strings2 = (1 to N).map(x => Seq(s"s$x"))
91+
val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType)))
92+
// The expected value joins the single element of each array with the separator.
checkEvaluation(
93+
ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow)
94+
}
95+
8396
test("elt") {
8497
def testElt(result: String, n: java.lang.Integer, args: String*): Unit = {
8598
checkEvaluation(

0 commit comments

Comments
 (0)