
Commit 1cce1a3

eatoncys authored and gatorsmile committed
[SPARK-21603][SQL] Whole-stage codegen can be much slower than when it is disabled if the generated function is too long
## What changes were proposed in this pull request?

Disable whole-stage codegen when a generated function is longer than the limit set by the `spark.sql.codegen.maxLinesPerFunction` parameter, because when the function is too long it will not be JIT-optimized. A benchmark shows about a 10x slowdown when the generated function is too long:

```scala
ignore("max function length of wholestagecodegen") {
  val N = 20 << 15

  val benchmark = new Benchmark("max function length of wholestagecodegen", N)
  def f(): Unit = sparkSession.range(N)
    .selectExpr(
      "id",
      "(id & 1023) as k1",
      "cast(id & 1023 as double) as k2",
      "cast(id & 1023 as int) as k3",
      "case when id > 100 and id <= 200 then 1 else 0 end as v1",
      "case when id > 200 and id <= 300 then 1 else 0 end as v2",
      "case when id > 300 and id <= 400 then 1 else 0 end as v3",
      "case when id > 400 and id <= 500 then 1 else 0 end as v4",
      "case when id > 500 and id <= 600 then 1 else 0 end as v5",
      "case when id > 600 and id <= 700 then 1 else 0 end as v6",
      "case when id > 700 and id <= 800 then 1 else 0 end as v7",
      "case when id > 800 and id <= 900 then 1 else 0 end as v8",
      "case when id > 900 and id <= 1000 then 1 else 0 end as v9",
      "case when id > 1000 and id <= 1100 then 1 else 0 end as v10",
      "case when id > 1100 and id <= 1200 then 1 else 0 end as v11",
      "case when id > 1200 and id <= 1300 then 1 else 0 end as v12",
      "case when id > 1300 and id <= 1400 then 1 else 0 end as v13",
      "case when id > 1400 and id <= 1500 then 1 else 0 end as v14",
      "case when id > 1500 and id <= 1600 then 1 else 0 end as v15",
      "case when id > 1600 and id <= 1700 then 1 else 0 end as v16",
      "case when id > 1700 and id <= 1800 then 1 else 0 end as v17",
      "case when id > 1800 and id <= 1900 then 1 else 0 end as v18")
    .groupBy("k1", "k2", "k3")
    .sum()
    .collect()

  benchmark.addCase(s"codegen = F") { iter =>
    sparkSession.conf.set("spark.sql.codegen.wholeStage", "false")
    f()
  }

  benchmark.addCase(s"codegen = T") { iter =>
    sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
    sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "10000")
    f()
  }

  benchmark.run()

  /*
  Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1
  Intel64 Family 6 Model 58 Stepping 9, GenuineIntel
  max function length of wholestagecodegen: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
  ------------------------------------------------------------------------------------------------
  codegen = F                                      443 /  507          1.5         676.0       1.0X
  codegen = T                                     3279 / 3283          0.2        5002.6       0.1X
  */
}
```

## How was this patch tested?

Run the unit tests.

Author: 10129659 <[email protected]>

Closes apache#18810 from eatoncys/codegen.
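Note that the fallback added here happens at execution time inside `WholeStageCodegenExec.doExecute`, so the query plan still contains a `WholeStageCodegenExec` node even when the too-long-function fallback fires; the fallback announces itself only through the warning log. A minimal sketch for checking whether the planner fused a stage at all (assuming a running `SparkSession` named `spark`; the query shape is illustrative only):

```scala
import org.apache.spark.sql.execution.WholeStageCodegenExec

// Any aggregation query works; this mirrors the benchmark's group-by shape.
val df = spark.range(1000)
  .selectExpr("id", "(id & 1023) as k1")
  .groupBy("k1")
  .sum()

// The plan contains a WholeStageCodegenExec node whenever fusion was planned;
// the SPARK-21603 fallback is decided later, at doExecute() time.
val fused = df.queryExecution.executedPlan
  .find(_.isInstanceOf[WholeStageCodegenExec])
  .isDefined
println(s"whole-stage codegen planned: $fused")
```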
1 parent adf005d · commit 1cce1a3

File tree

7 files changed: +193 -0 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala

Lines changed: 8 additions & 0 deletions
```diff
@@ -89,6 +89,14 @@ object CodeFormatter {
     }
     new CodeAndComment(code.result().trim(), map)
   }
+
+  def stripExtraNewLinesAndComments(input: String): String = {
+    val commentReg =
+      ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" +  // strip /*comment*/
+        """([ |\t]*?\/\/[\s\S]*?\n)""").r            // strip //comment
+    val codeWithoutComment = commentReg.replaceAllIn(input, "")
+    codeWithoutComment.replaceAll("""\n\s*\n""", "\n")  // strip extra new lines
+  }
 }
 
 private class CodeFormatter {
```
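To make the new helper's behavior concrete, here is a hedged, self-contained sketch that exercises the same two regexes outside Spark (the `StripDemo` object and its sample input are illustrative only):

```scala
object StripDemo {
  // Same stripping logic as CodeFormatter.stripExtraNewLinesAndComments:
  // remove /* ... */ and // ... comments, then collapse blank lines.
  def strip(input: String): String = {
    val commentReg =
      ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" +
        """([ |\t]*?\/\/[\s\S]*?\n)""").r
    val withoutComments = commentReg.replaceAllIn(input, "")
    withoutComments.replaceAll("""\n\s*\n""", "\n")
  }

  def main(args: Array[String]): Unit = {
    val code = "int x = 1; // counter\n\n/* block */\nint y = 2;\n"
    // Prints "int x = 1;" and "int y = 2;" on consecutive lines: the comments
    // and the blank line are gone, so a later line count reflects only code.
    print(strip(code))
  }
}
```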

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 14 additions & 0 deletions
```diff
@@ -355,6 +355,20 @@ class CodegenContext {
    */
   private val placeHolderToComments = new mutable.HashMap[String, String]
 
+  /**
+   * Returns true if any Java function generated by whole-stage codegen is longer
+   * than spark.sql.codegen.maxLinesPerFunction lines, counted after comments and
+   * extra new lines have been stripped.
+   */
+  def isTooLongGeneratedFunction: Boolean = {
+    classFunctions.values.exists { _.values.exists {
+      code =>
+        val codeWithoutComments = CodeFormatter.stripExtraNewLinesAndComments(code)
+        codeWithoutComments.count(_ == '\n') > SQLConf.get.maxLinesPerFunction
+      }
+    }
+  }
+
   /**
    * Returns a term name that is unique within this instance of a `CodegenContext`.
    */
```
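The check reduces to counting newlines in the stripped body of each generated function and comparing against the configured limit. A hedged standalone equivalent, with the function registry stubbed in as a plain nested `Map` (the names and sample body are hypothetical):

```scala
// Hypothetical stand-in for CodegenContext.classFunctions:
// outer key = generated class, inner key = function name, value = source text.
val classFunctions: Map[String, Map[String, String]] = Map(
  "GeneratedIterator" -> Map(
    "processNext" -> "void processNext() {\n  // ...\n}\n"))

// Mirrors isTooLongGeneratedFunction: true if any function body has more
// newline characters than the threshold allows.
def isTooLong(maxLinesPerFunction: Int): Boolean =
  classFunctions.values.exists(_.values.exists { code =>
    // The patch first applies CodeFormatter.stripExtraNewLinesAndComments so
    // comments and blank lines do not count; omitted here for brevity.
    code.count(_ == '\n') > maxLinesPerFunction
  })
```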

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
```diff
@@ -572,6 +572,16 @@ object SQLConf {
       "disable logging or -1 to apply no limit.")
     .createWithDefault(1000)
 
+  val WHOLESTAGE_MAX_LINES_PER_FUNCTION = buildConf("spark.sql.codegen.maxLinesPerFunction")
+    .internal()
+    .doc("The maximum number of lines in a single Java function generated by whole-stage " +
+      "codegen. When the generated function exceeds this threshold, whole-stage codegen is " +
+      "deactivated for this subtree of the current query plan. The default value 2667 is " +
+      "the maximum bytecode size that the JIT supports for a single function (8000) " +
+      "divided by 3.")
+    .intConf
+    .createWithDefault(2667)
+
   val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
     .doc("The maximum number of bytes to pack into a single partition when reading files.")
     .longConf
@@ -1037,6 +1047,8 @@ class SQLConf extends Serializable with Logging {
 
   def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES)
 
+  def maxLinesPerFunction: Int = getConf(WHOLESTAGE_MAX_LINES_PER_FUNCTION)
+
   def tableRelationCacheSize: Int =
     getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE)
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala

Lines changed: 32 additions & 0 deletions
```diff
@@ -53,6 +53,38 @@ class CodeFormatterSuite extends SparkFunSuite {
     assert(reducedCode.body === "/*project_c4*/")
   }
 
+  test("removing extra new lines and comments") {
+    val code =
+      """
+        |/*
+        | * multi
+        | * line
+        | * comments
+        | */
+        |
+        |public function() {
+        |/*comment*/
+        |  /*comment_with_space*/
+        |code_body
+        |//comment
+        |code_body
+        |  //comment_with_space
+        |
+        |code_body
+        |}
+      """.stripMargin
+
+    val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code)
+    assert(reducedCode ===
+      """
+        |public function() {
+        |code_body
+        |code_body
+        |code_body
+        |}
+      """.stripMargin)
+  }
+
   testCase("basic example") {
     """
       |class A {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 8 additions & 0 deletions
```diff
@@ -370,6 +370,14 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
 
   override def doExecute(): RDD[InternalRow] = {
     val (ctx, cleanedSource) = doCodeGen()
+    if (ctx.isTooLongGeneratedFunction) {
+      logWarning("Found a too long generated function, so JIT optimization might not work. " +
+        "Whole-stage codegen is disabled for this plan. You can change the config " +
+        "spark.sql.codegen.maxLinesPerFunction to adjust the function length limit:\n " +
+        s"$treeString")
+      return child.execute()
+    }
     // try to compile and fallback if it failed
     try {
       CodeGenerator.compile(cleanedSource)
```

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 57 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.{Column, Dataset, Row}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -149,4 +150,60 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       assert(df.collect() === Array(Row(1), Row(2)))
     }
   }
+
+  def genGroupByCodeGenContext(caseNum: Int): CodegenContext = {
+    val caseExp = (1 to caseNum).map { i =>
+      s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i"
+    }.toList
+    val keyExp = List(
+      "id",
+      "(id & 1023) as k1",
+      "cast(id & 1023 as double) as k2",
+      "cast(id & 1023 as int) as k3")
+
+    val ds = spark.range(10)
+      .selectExpr(keyExp ::: caseExp: _*)
+      .groupBy("k1", "k2", "k3")
+      .sum()
+    val plan = ds.queryExecution.executedPlan
+
+    val wholeStageCodeGenExec = plan.find(p => p match {
+      case wp: WholeStageCodegenExec => wp.child match {
+        case hp: HashAggregateExec if (hp.child.isInstanceOf[ProjectExec]) => true
+        case _ => false
+      }
+      case _ => false
+    })
+
+    assert(wholeStageCodeGenExec.isDefined)
+    wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._1
+  }
+
+  test("SPARK-21603 check there is a too long generated function") {
+    withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") {
+      val ctx = genGroupByCodeGenContext(30)
+      assert(ctx.isTooLongGeneratedFunction === true)
+    }
+  }
+
+  test("SPARK-21603 check there is not a too long generated function") {
+    withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") {
+      val ctx = genGroupByCodeGenContext(1)
+      assert(ctx.isTooLongGeneratedFunction === false)
+    }
+  }
+
+  test("SPARK-21603 check there is not a too long generated function when threshold is Int.Max") {
+    withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> Int.MaxValue.toString) {
+      val ctx = genGroupByCodeGenContext(30)
+      assert(ctx.isTooLongGeneratedFunction === false)
+    }
+  }
+
+  test("SPARK-21603 check there is a too long generated function when threshold is 0") {
+    withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "0") {
+      val ctx = genGroupByCodeGenContext(1)
+      assert(ctx.isTooLongGeneratedFunction === true)
+    }
+  }
 }
```

sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala

Lines changed: 62 additions & 0 deletions
```diff
@@ -301,6 +301,68 @@ class AggregateBenchmark extends BenchmarkBase {
     */
   }
 
+  ignore("max function length of wholestagecodegen") {
+    val N = 20 << 15
+
+    val benchmark = new Benchmark("max function length of wholestagecodegen", N)
+    def f(): Unit = sparkSession.range(N)
+      .selectExpr(
+        "id",
+        "(id & 1023) as k1",
+        "cast(id & 1023 as double) as k2",
+        "cast(id & 1023 as int) as k3",
+        "case when id > 100 and id <= 200 then 1 else 0 end as v1",
+        "case when id > 200 and id <= 300 then 1 else 0 end as v2",
+        "case when id > 300 and id <= 400 then 1 else 0 end as v3",
+        "case when id > 400 and id <= 500 then 1 else 0 end as v4",
+        "case when id > 500 and id <= 600 then 1 else 0 end as v5",
+        "case when id > 600 and id <= 700 then 1 else 0 end as v6",
+        "case when id > 700 and id <= 800 then 1 else 0 end as v7",
+        "case when id > 800 and id <= 900 then 1 else 0 end as v8",
+        "case when id > 900 and id <= 1000 then 1 else 0 end as v9",
+        "case when id > 1000 and id <= 1100 then 1 else 0 end as v10",
+        "case when id > 1100 and id <= 1200 then 1 else 0 end as v11",
+        "case when id > 1200 and id <= 1300 then 1 else 0 end as v12",
+        "case when id > 1300 and id <= 1400 then 1 else 0 end as v13",
+        "case when id > 1400 and id <= 1500 then 1 else 0 end as v14",
+        "case when id > 1500 and id <= 1600 then 1 else 0 end as v15",
+        "case when id > 1600 and id <= 1700 then 1 else 0 end as v16",
+        "case when id > 1700 and id <= 1800 then 1 else 0 end as v17",
+        "case when id > 1800 and id <= 1900 then 1 else 0 end as v18")
+      .groupBy("k1", "k2", "k3")
+      .sum()
+      .collect()
+
+    benchmark.addCase(s"codegen = F") { iter =>
+      sparkSession.conf.set("spark.sql.codegen.wholeStage", "false")
+      f()
+    }
+
+    benchmark.addCase(s"codegen = T maxLinesPerFunction = 10000") { iter =>
+      sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
+      sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "10000")
+      f()
+    }
+
+    benchmark.addCase(s"codegen = T maxLinesPerFunction = 1500") { iter =>
+      sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
+      sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "1500")
+      f()
+    }
+
+    benchmark.run()
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1
+    Intel64 Family 6 Model 58 Stepping 9, GenuineIntel
+    max function length of wholestagecodegen: Best/Avg Time(ms)   Rate(M/s)   Per Row(ns)   Relative
+    ----------------------------------------------------------------------------------------------
+    codegen = F                                       462 /  533         1.4         704.4       1.0X
+    codegen = T maxLinesPerFunction = 10000          3444 / 3447         0.2        5255.3       0.1X
+    codegen = T maxLinesPerFunction = 1500            447 /  478         1.5         682.1       1.0X
+    */
+  }
+
 
   ignore("cube") {
     val N = 5 << 20
```
