Skip to content

Commit c6f01ca

Browse files
mgaido91 authored and cloud-fan committed
[SPARK-22750][SQL] Reuse mutable states when possible
## What changes were proposed in this pull request?

The PR introduces a new method, `addImmutableStateIfNotExists`, to `CodegenContext` (in `CodeGenerator.scala`) to allow reusing and sharing the same global variable between different Expressions. This helps reduce the number of global variables needed, which is important to limit the impact on the constant pool.

## How was this patch tested?

Added UTs.

Author: Marco Gaido <[email protected]>
Author: Marco Gaido <[email protected]>

Closes #19940 from mgaido91/SPARK-22750.
1 parent c0abb1d commit c6f01ca

File tree

6 files changed

+80
-14
lines changed

6 files changed

+80
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
6666

6767
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
6868
val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
69-
val partitionMaskTerm = ctx.addMutableState(ctx.JAVA_LONG, "partitionMask")
69+
val partitionMaskTerm = "partitionMask"
70+
ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
7071
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
7172
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
7273

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
4343
override protected def evalInternal(input: InternalRow): Int = partitionId
4444

4545
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
46-
val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId")
46+
val idTerm = "partitionId"
47+
ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
4748
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
4849
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
4950
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,14 @@ class CodegenContext {
207207

208208
}
209209

210+
/**
211+
* A map containing the mutable states which have been defined so far using
212+
* `addImmutableStateIfNotExists`. Each entry contains the name of the mutable state as key and
213+
* its Java type and init code as value.
214+
*/
215+
private val immutableStates: mutable.Map[String, (String, String)] =
216+
mutable.Map.empty[String, (String, String)]
217+
210218
/**
211219
* Add a mutable state as a field to the generated class. c.f. the comments above.
212220
*
@@ -265,6 +273,38 @@ class CodegenContext {
265273
}
266274
}
267275

276+
/**
277+
* Add an immutable state as a field to the generated class only if there is not yet a field
278+
* with that name. This helps reduce the number of the generated class' fields, since the same
279+
* variable can be reused by many functions.
280+
*
281+
* Even though the added variables are not declared as final, they should never be reassigned in
282+
* the generated code to prevent errors and unexpected behaviors.
283+
*
284+
* Internally, this method calls `addMutableState`.
285+
*
286+
* @param javaType Java type of the field.
287+
* @param variableName Name of the field.
288+
* @param initFunc Function that returns the statement(s) to put into the init() method to initialize
289+
* this field. The argument is the name of the mutable state variable.
290+
*/
291+
def addImmutableStateIfNotExists(
292+
javaType: String,
293+
variableName: String,
294+
initFunc: String => String = _ => ""): Unit = {
295+
val existingImmutableState = immutableStates.get(variableName)
296+
if (existingImmutableState.isEmpty) {
297+
addMutableState(javaType, variableName, initFunc, useFreshName = false, forceInline = true)
298+
immutableStates(variableName) = (javaType, initFunc(variableName))
299+
} else {
300+
val (prevJavaType, prevInitCode) = existingImmutableState.get
301+
assert(prevJavaType == javaType, s"$variableName has already been defined with type " +
302+
s"$prevJavaType and now it is tried to define again with type $javaType.")
303+
assert(prevInitCode == initFunc(variableName), s"$variableName has already been defined " +
304+
s"with different initialization statements.")
305+
}
306+
}
307+
268308
/**
269309
* Add buffer variable which stores data coming from an [[InternalRow]]. This method guarantees
270310
* that the variable is safely stored, which is important for (potentially) byte array backed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,8 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas
443443
nullSafeCodeGen(ctx, ev, time => {
444444
val cal = classOf[Calendar].getName
445445
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
446-
val c = ctx.addMutableState(cal, "cal",
446+
val c = "calDayOfWeek"
447+
ctx.addImmutableStateIfNotExists(cal, c,
447448
v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""")
448449
s"""
449450
$c.setTimeInMillis($time * 1000L * 3600L * 24L);
@@ -484,8 +485,9 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
484485
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
485486
nullSafeCodeGen(ctx, ev, time => {
486487
val cal = classOf[Calendar].getName
488+
val c = "calWeekOfYear"
487489
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
488-
val c = ctx.addMutableState(cal, "cal", v =>
490+
ctx.addImmutableStateIfNotExists(cal, c, v =>
489491
s"""
490492
|$v = $cal.getInstance($dtu.getTimeZone("UTC"));
491493
|$v.setFirstDayOfWeek($cal.MONDAY);
@@ -1017,7 +1019,8 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
10171019
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
10181020
val tzTerm = ctx.addMutableState(tzClass, "tz",
10191021
v => s"""$v = $dtu.getTimeZone("$tz");""")
1020-
val utcTerm = ctx.addMutableState(tzClass, "utc",
1022+
val utcTerm = "tzUTC"
1023+
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
10211024
v => s"""$v = $dtu.getTimeZone("UTC");""")
10221025
val eval = left.genCode(ctx)
10231026
ev.copy(code = s"""
@@ -1193,7 +1196,8 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
11931196
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
11941197
val tzTerm = ctx.addMutableState(tzClass, "tz",
11951198
v => s"""$v = $dtu.getTimeZone("$tz");""")
1196-
val utcTerm = ctx.addMutableState(tzClass, "utc",
1199+
val utcTerm = "tzUTC"
1200+
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
11971201
v => s"""$v = $dtu.getTimeZone("UTC");""")
11981202
val eval = left.genCode(ctx)
11991203
ev.copy(code = s"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,17 +1148,21 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
11481148

11491149
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
11501150
// Code to initialize the serializer.
1151-
val (serializerClass, serializerInstanceClass) = {
1151+
val (serializer, serializerClass, serializerInstanceClass) = {
11521152
if (kryo) {
1153-
(classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
1153+
("kryoSerializer",
1154+
classOf[KryoSerializer].getName,
1155+
classOf[KryoSerializerInstance].getName)
11541156
} else {
1155-
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
1157+
("javaSerializer",
1158+
classOf[JavaSerializer].getName,
1159+
classOf[JavaSerializerInstance].getName)
11561160
}
11571161
}
11581162
// try conf from env, otherwise create a new one
11591163
val env = s"${classOf[SparkEnv].getName}.get()"
11601164
val sparkConf = s"new ${classOf[SparkConf].getName}()"
1161-
val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForEncode", v =>
1165+
ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
11621166
s"""
11631167
|if ($env == null) {
11641168
| $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
@@ -1193,17 +1197,21 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
11931197

11941198
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
11951199
// Code to initialize the serializer.
1196-
val (serializerClass, serializerInstanceClass) = {
1200+
val (serializer, serializerClass, serializerInstanceClass) = {
11971201
if (kryo) {
1198-
(classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
1202+
("kryoSerializer",
1203+
classOf[KryoSerializer].getName,
1204+
classOf[KryoSerializerInstance].getName)
11991205
} else {
1200-
(classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
1206+
("javaSerializer",
1207+
classOf[JavaSerializer].getName,
1208+
classOf[JavaSerializerInstance].getName)
12011209
}
12021210
}
12031211
// try conf from env, otherwise create a new one
12041212
val env = s"${classOf[SparkEnv].getName}.get()"
12051213
val sparkConf = s"new ${classOf[SparkConf].getName}()"
1206-
val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode", v =>
1214+
ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
12071215
s"""
12081216
|if ($env == null) {
12091217
| $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,4 +424,16 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
424424
assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex == 10)
425425
assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10)
426426
}
427+
428+
test("SPARK-22750: addImmutableStateIfNotExists") {
429+
val ctx = new CodegenContext
430+
val mutableState1 = "field1"
431+
val mutableState2 = "field2"
432+
ctx.addImmutableStateIfNotExists("int", mutableState1)
433+
ctx.addImmutableStateIfNotExists("int", mutableState1)
434+
ctx.addImmutableStateIfNotExists("String", mutableState2)
435+
ctx.addImmutableStateIfNotExists("int", mutableState1)
436+
ctx.addImmutableStateIfNotExists("String", mutableState2)
437+
assert(ctx.inlinedMutableStates.length == 2)
438+
}
427439
}

0 commit comments

Comments
 (0)