Skip to content

Commit 96e947e

Browse files
gatorsmile authored and cloud-fan committed
[SPARK-22569][SQL] Clean usage of addMutableState and splitExpressions
## What changes were proposed in this pull request? This PR cleans up the usage of addMutableState and splitExpressions: 1. Replace hardcoded type strings with ctx.JAVA_BOOLEAN etc. 2. Add a default value for the initCode parameter of ctx.addMutableState. 3. Use named arguments when calling `splitExpressions`. ## How was this patch tested? The existing test cases. Author: gatorsmile <[email protected]> Closes #19790 from gatorsmile/codeClean.
1 parent 9bdff0b commit 96e947e

20 files changed

+104
-83
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
6767
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
6868
val countTerm = ctx.freshName("count")
6969
val partitionMaskTerm = ctx.freshName("partitionMask")
70-
ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
71-
ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
70+
ctx.addMutableState(ctx.JAVA_LONG, countTerm)
71+
ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm)
7272
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
7373
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
7474

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
4444

4545
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
4646
val idTerm = ctx.freshName("partitionId")
47-
ctx.addMutableState(ctx.JAVA_INT, idTerm, "")
47+
ctx.addMutableState(ctx.JAVA_INT, idTerm)
4848
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
4949
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
5050
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -602,8 +602,8 @@ case class Least(children: Seq[Expression]) extends Expression {
602602

603603
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
604604
val evalChildren = children.map(_.genCode(ctx))
605-
ctx.addMutableState("boolean", ev.isNull, "")
606-
ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
605+
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
606+
ctx.addMutableState(ctx.javaType(dataType), ev.value)
607607
def updateEval(eval: ExprCode): String = {
608608
s"""
609609
${eval.code}
@@ -668,8 +668,8 @@ case class Greatest(children: Seq[Expression]) extends Expression {
668668

669669
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
670670
val evalChildren = children.map(_.genCode(ctx))
671-
ctx.addMutableState("boolean", ev.isNull, "")
672-
ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
671+
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
672+
ctx.addMutableState(ctx.javaType(dataType), ev.value)
673673
def updateEval(eval: ExprCode): String = {
674674
s"""
675675
${eval.code}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,19 @@ class CodegenContext {
157157
val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
158158
mutable.ArrayBuffer.empty[(String, String, String)]
159159

160-
def addMutableState(javaType: String, variableName: String, initCode: String): Unit = {
160+
/**
161+
* Add a mutable state as a field to the generated class. cf. the comments above.
162+
*
163+
* @param javaType Java type of the field. Note that short names can be used for some types,
164+
* e.g. InternalRow, UnsafeRow, UnsafeArrayData, etc. Other types will have to
165+
* specify the fully-qualified Java type name. See the code in doCompile() for
166+
* the list of default imports available.
167+
* Also, generic type arguments are accepted but ignored.
168+
* @param variableName Name of the field.
169+
* @param initCode The statement(s) to put into the init() method to initialize this field.
170+
* If left blank, the field will be default-initialized.
171+
*/
172+
def addMutableState(javaType: String, variableName: String, initCode: String = ""): Unit = {
161173
mutableStates += ((javaType, variableName, initCode))
162174
}
163175

@@ -191,7 +203,7 @@ class CodegenContext {
191203
val initCodes = mutableStates.distinct.map(_._3 + "\n")
192204
// The generated initialization code may exceed 64kb function size limit in JVM if there are too
193205
// many mutable states, so split it into multiple functions.
194-
splitExpressions(initCodes, "init", Nil)
206+
splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil)
195207
}
196208

197209
/**
@@ -769,7 +781,7 @@ class CodegenContext {
769781
// Cannot split these expressions because they are not created from a row object.
770782
return expressions.mkString("\n")
771783
}
772-
splitExpressions(expressions, "apply", ("InternalRow", row) :: Nil)
784+
splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", row) :: Nil)
773785
}
774786

775787
/**
@@ -931,7 +943,7 @@ class CodegenContext {
931943
dataType: DataType,
932944
baseFuncName: String): (String, String, String) = {
933945
val globalIsNull = freshName("isNull")
934-
addMutableState("boolean", globalIsNull, s"$globalIsNull = false;")
946+
addMutableState(JAVA_BOOLEAN, globalIsNull, s"$globalIsNull = false;")
935947
val globalValue = freshName("value")
936948
addMutableState(javaType(dataType), globalValue,
937949
s"$globalValue = ${defaultValue(dataType)};")
@@ -1038,7 +1050,7 @@ class CodegenContext {
10381050
// 2. Less code.
10391051
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
10401052
// at least two nodes) as the cost of doing it is expected to be low.
1041-
addMutableState("boolean", isNull, s"$isNull = false;")
1053+
addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;")
10421054
addMutableState(javaType(expr.dataType), value,
10431055
s"$value = ${defaultValue(expr.dataType)};")
10441056

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
6363
if (e.nullable) {
6464
val isNull = s"isNull_$i"
6565
val value = s"value_$i"
66-
ctx.addMutableState("boolean", isNull, s"$isNull = true;")
66+
ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull, s"$isNull = true;")
6767
ctx.addMutableState(ctx.javaType(e.dataType), value,
6868
s"$value = ${ctx.defaultValue(e.dataType)};")
6969
s"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ private [sql] object GenArrayData {
120120
UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
121121
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
122122
val baseOffset = Platform.BYTE_ARRAY_OFFSET
123-
ctx.addMutableState("UnsafeArrayData", arrayDataName, "")
123+
ctx.addMutableState("UnsafeArrayData", arrayDataName)
124124

125125
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
126126
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ abstract class HashExpression[E] extends Expression {
277277
}
278278
})
279279

280-
ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
280+
ctx.addMutableState(ctx.javaType(dataType), ev.value)
281281
ev.copy(code = s"""
282282
${ev.value} = $seed;
283283
$childrenHash""")
@@ -616,8 +616,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
616616
s"\n$childHash = 0;"
617617
})
618618

619-
ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
620-
ctx.addMutableState("int", childHash, s"$childHash = 0;")
619+
ctx.addMutableState(ctx.javaType(dataType), ev.value)
620+
ctx.addMutableState(ctx.JAVA_INT, childHash, s"$childHash = 0;")
621621
ev.copy(code = s"""
622622
${ev.value} = $seed;
623623
$childrenHash""")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
7272
}
7373

7474
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
75-
ctx.addMutableState("boolean", ev.isNull, "")
76-
ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
75+
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
76+
ctx.addMutableState(ctx.javaType(dataType), ev.value)
7777

7878
val evals = children.map { e =>
7979
val eval = e.genCode(ctx)
@@ -385,8 +385,10 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
385385
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
386386
evals.mkString("\n")
387387
} else {
388-
ctx.splitExpressions(evals, "atLeastNNonNulls",
389-
("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil,
388+
ctx.splitExpressions(
389+
expressions = evals,
390+
funcName = "atLeastNNonNulls",
391+
arguments = ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil,
390392
returnType = "int",
391393
makeSplitFunction = { body =>
392394
s"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@ trait InvokeLike extends Expression with NonSQLExpression {
6363

6464
val resultIsNull = if (needNullCheck) {
6565
val resultIsNull = ctx.freshName("resultIsNull")
66-
ctx.addMutableState("boolean", resultIsNull, "")
66+
ctx.addMutableState(ctx.JAVA_BOOLEAN, resultIsNull)
6767
resultIsNull
6868
} else {
6969
"false"
7070
}
7171
val argValues = arguments.map { e =>
7272
val argValue = ctx.freshName("argValue")
73-
ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
73+
ctx.addMutableState(ctx.javaType(e.dataType), argValue)
7474
argValue
7575
}
7676

@@ -548,7 +548,7 @@ case class MapObjects private(
548548

549549
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
550550
val elementJavaType = ctx.javaType(loopVarDataType)
551-
ctx.addMutableState(elementJavaType, loopValue, "")
551+
ctx.addMutableState(elementJavaType, loopValue)
552552
val genInputData = inputData.genCode(ctx)
553553
val genFunction = lambdaFunction.genCode(ctx)
554554
val dataLength = ctx.freshName("dataLength")
@@ -644,7 +644,7 @@ case class MapObjects private(
644644
}
645645

646646
val loopNullCheck = if (loopIsNull != "false") {
647-
ctx.addMutableState("boolean", loopIsNull, "")
647+
ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull)
648648
inputDataType match {
649649
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
650650
case _ => s"$loopIsNull = $loopValue == null;"
@@ -808,10 +808,10 @@ case class CatalystToExternalMap private(
808808

809809
val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
810810
val keyElementJavaType = ctx.javaType(mapType.keyType)
811-
ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
811+
ctx.addMutableState(keyElementJavaType, keyLoopValue)
812812
val genKeyFunction = keyLambdaFunction.genCode(ctx)
813813
val valueElementJavaType = ctx.javaType(mapType.valueType)
814-
ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
814+
ctx.addMutableState(valueElementJavaType, valueLoopValue)
815815
val genValueFunction = valueLambdaFunction.genCode(ctx)
816816
val genInputData = inputData.genCode(ctx)
817817
val dataLength = ctx.freshName("dataLength")
@@ -844,7 +844,7 @@ case class CatalystToExternalMap private(
844844
val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
845845

846846
val valueLoopNullCheck = if (valueLoopIsNull != "false") {
847-
ctx.addMutableState("boolean", valueLoopIsNull, "")
847+
ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull)
848848
s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
849849
} else {
850850
""
@@ -994,8 +994,8 @@ case class ExternalMapToCatalyst private(
994994

995995
val keyElementJavaType = ctx.javaType(keyType)
996996
val valueElementJavaType = ctx.javaType(valueType)
997-
ctx.addMutableState(keyElementJavaType, key, "")
998-
ctx.addMutableState(valueElementJavaType, value, "")
997+
ctx.addMutableState(keyElementJavaType, key)
998+
ctx.addMutableState(valueElementJavaType, value)
999999

10001000
val (defineEntries, defineKeyValue) = child.dataType match {
10011001
case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
@@ -1031,14 +1031,14 @@ case class ExternalMapToCatalyst private(
10311031
}
10321032

10331033
val keyNullCheck = if (keyIsNull != "false") {
1034-
ctx.addMutableState("boolean", keyIsNull, "")
1034+
ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull)
10351035
s"$keyIsNull = $key == null;"
10361036
} else {
10371037
""
10381038
}
10391039

10401040
val valueNullCheck = if (valueIsNull != "false") {
1041-
ctx.addMutableState("boolean", valueIsNull, "")
1041+
ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull)
10421042
s"$valueIsNull = $value == null;"
10431043
} else {
10441044
""
@@ -1106,7 +1106,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
11061106
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
11071107
val rowClass = classOf[GenericRowWithSchema].getName
11081108
val values = ctx.freshName("values")
1109-
ctx.addMutableState("Object[]", values, "")
1109+
ctx.addMutableState("Object[]", values)
11101110

11111111
val childrenCodes = children.zipWithIndex.map { case (e, i) =>
11121112
val eval = e.genCode(ctx)
@@ -1244,7 +1244,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
12441244

12451245
val javaBeanInstance = ctx.freshName("javaBean")
12461246
val beanInstanceJavaType = ctx.javaType(beanInstance.dataType)
1247-
ctx.addMutableState(beanInstanceJavaType, javaBeanInstance, "")
1247+
ctx.addMutableState(beanInstanceJavaType, javaBeanInstance)
12481248

12491249
val initialize = setters.map {
12501250
case (setterMethod, fieldValue) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
236236
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
237237
val valueGen = value.genCode(ctx)
238238
val listGen = list.map(_.genCode(ctx))
239-
ctx.addMutableState("boolean", ev.value, "")
240-
ctx.addMutableState("boolean", ev.isNull, "")
239+
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value)
240+
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
241241
val valueArg = ctx.freshName("valueArg")
242242
val listCode = listGen.map(x =>
243243
s"""
@@ -253,7 +253,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
253253
""")
254254
val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
255255
val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil
256-
ctx.splitExpressions(listCode, "valueIn", args)
256+
ctx.splitExpressions(expressions = listCode, funcName = "valueIn", arguments = args)
257257
} else {
258258
listCode.mkString("\n")
259259
}

0 commit comments

Comments
 (0)