Skip to content

Commit c5583fd

Browse files
kiszkueshin
authored andcommitted
[SPARK-23466][SQL] Remove redundant null checks in generated Java code by GenerateUnsafeProjection
## What changes were proposed in this pull request? This PR works for one of TODOs in `GenerateUnsafeProjection` "if the nullability of field is correct, we can use it to save null check" to simplify generated code. When `nullable=false` in `DataType`, `GenerateUnsafeProjection` removed code for null checks in the generated Java code. ## How was this patch tested? Added new test cases into `GenerateUnsafeProjectionSuite` Closes apache#20637 from kiszk/SPARK-23466. Authored-by: Kazuaki Ishizaki <[email protected]> Signed-off-by: Takuya UESHIN <[email protected]>
1 parent e1d72f2 commit c5583fd

File tree

3 files changed

+117
-33
lines changed

3 files changed

+117
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import org.apache.spark.sql.types._
3232
*/
3333
object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
3434

35+
case class Schema(dataType: DataType, nullable: Boolean)
36+
3537
/** Returns true iff we support this data type. */
3638
def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match {
3739
case NullType => true
@@ -43,19 +45,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
4345
case _ => false
4446
}
4547

46-
// TODO: if the nullability of field is correct, we can use it to save null check.
4748
private def writeStructToBuffer(
4849
ctx: CodegenContext,
4950
input: String,
5051
index: String,
51-
fieldTypes: Seq[DataType],
52+
schemas: Seq[Schema],
5253
rowWriter: String): String = {
5354
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
5455
val tmpInput = ctx.freshName("tmpInput")
55-
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
56-
ExprCode(
57-
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"),
58-
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
56+
val fieldEvals = schemas.zipWithIndex.map { case (Schema(dt, nullable), i) =>
57+
val isNull = if (nullable) {
58+
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)")
59+
} else {
60+
FalseLiteral
61+
}
62+
ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
5963
}
6064

6165
val rowWriterClass = classOf[UnsafeRowWriter].getName
@@ -70,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
7074
| // Remember the current cursor so that we can calculate how many bytes are
7175
| // written later.
7276
| final int $previousCursor = $rowWriter.cursor();
73-
| ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)}
77+
| ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)}
7478
| $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
7579
|}
7680
""".stripMargin
@@ -80,7 +84,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
8084
ctx: CodegenContext,
8185
row: String,
8286
inputs: Seq[ExprCode],
83-
inputTypes: Seq[DataType],
87+
schemas: Seq[Schema],
8488
rowWriter: String,
8589
isTopLevel: Boolean = false): String = {
8690
val resetWriter = if (isTopLevel) {
@@ -98,8 +102,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
98102
s"$rowWriter.resetRowWriter();"
99103
}
100104

101-
val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
102-
case ((input, dataType), index) =>
105+
val writeFields = inputs.zip(schemas).zipWithIndex.map {
106+
case ((input, Schema(dataType, nullable)), index) =>
103107
val dt = UserDefinedType.sqlType(dataType)
104108

105109
val setNull = dt match {
@@ -110,7 +114,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
110114
}
111115

112116
val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
113-
if (input.isNull == FalseLiteral) {
117+
if (!nullable) {
114118
s"""
115119
|${input.code}
116120
|${writeField.trim}
@@ -143,11 +147,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
143147
""".stripMargin
144148
}
145149

146-
// TODO: if the nullability of array element is correct, we can use it to save null check.
147150
private def writeArrayToBuffer(
148151
ctx: CodegenContext,
149152
input: String,
150153
elementType: DataType,
154+
containsNull: Boolean,
151155
rowWriter: String): String = {
152156
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
153157
val tmpInput = ctx.freshName("tmpInput")
@@ -170,6 +174,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
170174

171175
val element = CodeGenerator.getValue(tmpInput, et, index)
172176

177+
val elementAssignment = if (containsNull) {
178+
s"""
179+
|if ($tmpInput.isNullAt($index)) {
180+
| $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
181+
|} else {
182+
| ${writeElement(ctx, element, index, et, arrayWriter)}
183+
|}
184+
""".stripMargin
185+
} else {
186+
writeElement(ctx, element, index, et, arrayWriter)
187+
}
188+
173189
s"""
174190
|final ArrayData $tmpInput = $input;
175191
|if ($tmpInput instanceof UnsafeArrayData) {
@@ -179,30 +195,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
179195
| $arrayWriter.initialize($numElements);
180196
|
181197
| for (int $index = 0; $index < $numElements; $index++) {
182-
| if ($tmpInput.isNullAt($index)) {
183-
| $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
184-
| } else {
185-
| ${writeElement(ctx, element, index, et, arrayWriter)}
186-
| }
198+
| $elementAssignment
187199
| }
188200
|}
189201
""".stripMargin
190202
}
191203

192-
// TODO: if the nullability of value element is correct, we can use it to save null check.
193204
private def writeMapToBuffer(
194205
ctx: CodegenContext,
195206
input: String,
196207
index: String,
197208
keyType: DataType,
198209
valueType: DataType,
210+
valueContainsNull: Boolean,
199211
rowWriter: String): String = {
200212
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
201213
val tmpInput = ctx.freshName("tmpInput")
202214
val tmpCursor = ctx.freshName("tmpCursor")
203215
val previousCursor = ctx.freshName("previousCursor")
204216

205217
// Writes out unsafe map according to the format described in `UnsafeMapData`.
218+
val keyArray = writeArrayToBuffer(
219+
ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter)
220+
val valueArray = writeArrayToBuffer(
221+
ctx, s"$tmpInput.valueArray()", valueType, valueContainsNull, rowWriter)
222+
206223
s"""
207224
|final MapData $tmpInput = $input;
208225
|if ($tmpInput instanceof UnsafeMapData) {
@@ -219,15 +236,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
219236
| // Remember the current cursor so that we can write numBytes of key array later.
220237
| final int $tmpCursor = $rowWriter.cursor();
221238
|
222-
| ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)}
239+
| $keyArray
223240
|
224241
| // Write the numBytes of key array into the first 8 bytes.
225242
| Platform.putLong(
226243
| $rowWriter.getBuffer(),
227244
| $tmpCursor - 8,
228245
| $rowWriter.cursor() - $tmpCursor);
229246
|
230-
| ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)}
247+
| $valueArray
231248
| $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
232249
|}
233250
""".stripMargin
@@ -240,20 +257,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
240257
dt: DataType,
241258
writer: String): String = dt match {
242259
case t: StructType =>
243-
writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer)
260+
writeStructToBuffer(
261+
ctx, input, index, t.map(e => Schema(e.dataType, e.nullable)), writer)
244262

245-
case ArrayType(et, _) =>
263+
case ArrayType(et, en) =>
246264
val previousCursor = ctx.freshName("previousCursor")
247265
s"""
248266
|// Remember the current cursor so that we can calculate how many bytes are
249267
|// written later.
250268
|final int $previousCursor = $writer.cursor();
251-
|${writeArrayToBuffer(ctx, input, et, writer)}
269+
|${writeArrayToBuffer(ctx, input, et, en, writer)}
252270
|$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
253271
""".stripMargin
254272

255-
case MapType(kt, vt, _) =>
256-
writeMapToBuffer(ctx, input, index, kt, vt, writer)
273+
case MapType(kt, vt, vn) =>
274+
writeMapToBuffer(ctx, input, index, kt, vt, vn, writer)
257275

258276
case DecimalType.Fixed(precision, scale) =>
259277
s"$writer.write($index, $input, $precision, $scale);"
@@ -268,12 +286,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
268286
expressions: Seq[Expression],
269287
useSubexprElimination: Boolean = false): ExprCode = {
270288
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
271-
val exprTypes = expressions.map(_.dataType)
289+
val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))
272290

273-
val numVarLenFields = exprTypes.count {
274-
case dt if UnsafeRow.isFixedLength(dt) => false
291+
val numVarLenFields = exprSchemas.count {
292+
case Schema(dt, _) => !UnsafeRow.isFixedLength(dt)
275293
// TODO: consider large decimal and interval type
276-
case _ => true
277294
}
278295

279296
val rowWriterClass = classOf[UnsafeRowWriter].getName
@@ -284,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
284301
val evalSubexpr = ctx.subexprFunctions.mkString("\n")
285302

286303
val writeExpressions = writeExpressionsToBuffer(
287-
ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
304+
ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true)
288305

289306
val code =
290307
code"""

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
694694
|""".stripMargin
695695
val jsonSchema = new StructType()
696696
.add("a", LongType, nullable = false)
697-
.add("b", StringType, nullable = false)
697+
.add("b", StringType, nullable = !forceJsonNullableSchema)
698698
.add("c", StringType, nullable = false)
699699
val output = InternalRow(1L, null, UTF8String.fromString("foo"))
700700
val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.expressions.BoundReference
23-
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
24-
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
23+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
24+
import org.apache.spark.sql.types._
2525
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
2626

2727
class GenerateUnsafeProjectionSuite extends SparkFunSuite {
@@ -33,6 +33,41 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite {
3333
assert(!result.isNullAt(0))
3434
assert(result.getStruct(0, 1).isNullAt(0))
3535
}
36+
37+
test("Test unsafe projection for array/map/struct") {
38+
val dataType1 = ArrayType(StringType, false)
39+
val exprs1 = BoundReference(0, dataType1, nullable = false) :: Nil
40+
val projection1 = GenerateUnsafeProjection.generate(exprs1)
41+
val result1 = projection1.apply(AlwaysNonNull)
42+
assert(!result1.isNullAt(0))
43+
assert(!result1.getArray(0).isNullAt(0))
44+
assert(!result1.getArray(0).isNullAt(1))
45+
assert(!result1.getArray(0).isNullAt(2))
46+
47+
val dataType2 = MapType(StringType, StringType, false)
48+
val exprs2 = BoundReference(0, dataType2, nullable = false) :: Nil
49+
val projection2 = GenerateUnsafeProjection.generate(exprs2)
50+
val result2 = projection2.apply(AlwaysNonNull)
51+
assert(!result2.isNullAt(0))
52+
assert(!result2.getMap(0).keyArray.isNullAt(0))
53+
assert(!result2.getMap(0).keyArray.isNullAt(1))
54+
assert(!result2.getMap(0).keyArray.isNullAt(2))
55+
assert(!result2.getMap(0).valueArray.isNullAt(0))
56+
assert(!result2.getMap(0).valueArray.isNullAt(1))
57+
assert(!result2.getMap(0).valueArray.isNullAt(2))
58+
59+
val dataType3 = (new StructType)
60+
.add("a", StringType, nullable = false)
61+
.add("b", StringType, nullable = false)
62+
.add("c", StringType, nullable = false)
63+
val exprs3 = BoundReference(0, dataType3, nullable = false) :: Nil
64+
val projection3 = GenerateUnsafeProjection.generate(exprs3)
65+
val result3 = projection3.apply(InternalRow(AlwaysNonNull))
66+
assert(!result3.isNullAt(0))
67+
assert(!result3.getStruct(0, 1).isNullAt(0))
68+
assert(!result3.getStruct(0, 2).isNullAt(0))
69+
assert(!result3.getStruct(0, 3).isNullAt(0))
70+
}
3671
}
3772

3873
object AlwaysNull extends InternalRow {
@@ -59,3 +94,35 @@ object AlwaysNull extends InternalRow {
5994
override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
6095
private def notSupported: Nothing = throw new UnsupportedOperationException
6196
}
97+
98+
object AlwaysNonNull extends InternalRow {
99+
private def stringToUTF8Array(stringArray: Array[String]): ArrayData = {
100+
val utf8Array = stringArray.map(s => UTF8String.fromString(s)).toArray
101+
ArrayData.toArrayData(utf8Array)
102+
}
103+
override def numFields: Int = 1
104+
override def setNullAt(i: Int): Unit = {}
105+
override def copy(): InternalRow = this
106+
override def anyNull: Boolean = notSupported
107+
override def isNullAt(ordinal: Int): Boolean = notSupported
108+
override def update(i: Int, value: Any): Unit = notSupported
109+
override def getBoolean(ordinal: Int): Boolean = notSupported
110+
override def getByte(ordinal: Int): Byte = notSupported
111+
override def getShort(ordinal: Int): Short = notSupported
112+
override def getInt(ordinal: Int): Int = notSupported
113+
override def getLong(ordinal: Int): Long = notSupported
114+
override def getFloat(ordinal: Int): Float = notSupported
115+
override def getDouble(ordinal: Int): Double = notSupported
116+
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
117+
override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test")
118+
override def getBinary(ordinal: Int): Array[Byte] = notSupported
119+
override def getInterval(ordinal: Int): CalendarInterval = notSupported
120+
override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
121+
override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3"))
122+
val keyArray = stringToUTF8Array(Array("1", "2", "3"))
123+
val valueArray = stringToUTF8Array(Array("a", "b", "c"))
124+
override def getMap(ordinal: Int): MapData = new ArrayBasedMapData(keyArray, valueArray)
125+
override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
126+
private def notSupported: Nothing = throw new UnsupportedOperationException
127+
128+
}

0 commit comments

Comments
 (0)