@@ -32,6 +32,8 @@ import org.apache.spark.sql.types._
32
32
*/
33
33
object GenerateUnsafeProjection extends CodeGenerator [Seq [Expression ], UnsafeProjection ] {
34
34
35
+ case class Schema (dataType : DataType , nullable : Boolean )
36
+
35
37
/** Returns true iff we support this data type. */
36
38
def canSupport (dataType : DataType ): Boolean = UserDefinedType .sqlType(dataType) match {
37
39
case NullType => true
@@ -43,19 +45,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
43
45
case _ => false
44
46
}
45
47
46
- // TODO: if the nullability of field is correct, we can use it to save null check.
47
48
private def writeStructToBuffer (
48
49
ctx : CodegenContext ,
49
50
input : String ,
50
51
index : String ,
51
- fieldTypes : Seq [DataType ],
52
+ schemas : Seq [Schema ],
52
53
rowWriter : String ): String = {
53
54
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
54
55
val tmpInput = ctx.freshName(" tmpInput" )
55
- val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
56
- ExprCode (
57
- JavaCode .isNullExpression(s " $tmpInput.isNullAt( $i) " ),
58
- JavaCode .expression(CodeGenerator .getValue(tmpInput, dt, i.toString), dt))
56
+ val fieldEvals = schemas.zipWithIndex.map { case (Schema (dt, nullable), i) =>
57
+ val isNull = if (nullable) {
58
+ JavaCode .isNullExpression(s " $tmpInput.isNullAt( $i) " )
59
+ } else {
60
+ FalseLiteral
61
+ }
62
+ ExprCode (isNull, JavaCode .expression(CodeGenerator .getValue(tmpInput, dt, i.toString), dt))
59
63
}
60
64
61
65
val rowWriterClass = classOf [UnsafeRowWriter ].getName
@@ -70,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
70
74
| // Remember the current cursor so that we can calculate how many bytes are
71
75
| // written later.
72
76
| final int $previousCursor = $rowWriter.cursor();
73
- | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes , structRowWriter)}
77
+ | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas , structRowWriter)}
74
78
| $rowWriter.setOffsetAndSizeFromPreviousCursor( $index, $previousCursor);
75
79
|}
76
80
""" .stripMargin
@@ -80,7 +84,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
80
84
ctx : CodegenContext ,
81
85
row : String ,
82
86
inputs : Seq [ExprCode ],
83
- inputTypes : Seq [DataType ],
87
+ schemas : Seq [Schema ],
84
88
rowWriter : String ,
85
89
isTopLevel : Boolean = false ): String = {
86
90
val resetWriter = if (isTopLevel) {
@@ -98,8 +102,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
98
102
s " $rowWriter.resetRowWriter(); "
99
103
}
100
104
101
- val writeFields = inputs.zip(inputTypes ).zipWithIndex.map {
102
- case ((input, dataType), index) =>
105
+ val writeFields = inputs.zip(schemas ).zipWithIndex.map {
106
+ case ((input, Schema ( dataType, nullable) ), index) =>
103
107
val dt = UserDefinedType .sqlType(dataType)
104
108
105
109
val setNull = dt match {
@@ -110,7 +114,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
110
114
}
111
115
112
116
val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
113
- if (input.isNull == FalseLiteral ) {
117
+ if (! nullable ) {
114
118
s """
115
119
| ${input.code}
116
120
| ${writeField.trim}
@@ -143,11 +147,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
143
147
""" .stripMargin
144
148
}
145
149
146
- // TODO: if the nullability of array element is correct, we can use it to save null check.
147
150
private def writeArrayToBuffer (
148
151
ctx : CodegenContext ,
149
152
input : String ,
150
153
elementType : DataType ,
154
+ containsNull : Boolean ,
151
155
rowWriter : String ): String = {
152
156
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
153
157
val tmpInput = ctx.freshName(" tmpInput" )
@@ -170,6 +174,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
170
174
171
175
val element = CodeGenerator .getValue(tmpInput, et, index)
172
176
177
+ val elementAssignment = if (containsNull) {
178
+ s """
179
+ |if ( $tmpInput.isNullAt( $index)) {
180
+ | $arrayWriter.setNull ${elementOrOffsetSize}Bytes( $index);
181
+ |} else {
182
+ | ${writeElement(ctx, element, index, et, arrayWriter)}
183
+ |}
184
+ """ .stripMargin
185
+ } else {
186
+ writeElement(ctx, element, index, et, arrayWriter)
187
+ }
188
+
173
189
s """
174
190
|final ArrayData $tmpInput = $input;
175
191
|if ( $tmpInput instanceof UnsafeArrayData) {
@@ -179,30 +195,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
179
195
| $arrayWriter.initialize( $numElements);
180
196
|
181
197
| for (int $index = 0; $index < $numElements; $index++) {
182
- | if ( $tmpInput.isNullAt( $index)) {
183
- | $arrayWriter.setNull ${elementOrOffsetSize}Bytes( $index);
184
- | } else {
185
- | ${writeElement(ctx, element, index, et, arrayWriter)}
186
- | }
198
+ | $elementAssignment
187
199
| }
188
200
|}
189
201
""" .stripMargin
190
202
}
191
203
192
- // TODO: if the nullability of value element is correct, we can use it to save null check.
193
204
private def writeMapToBuffer (
194
205
ctx : CodegenContext ,
195
206
input : String ,
196
207
index : String ,
197
208
keyType : DataType ,
198
209
valueType : DataType ,
210
+ valueContainsNull : Boolean ,
199
211
rowWriter : String ): String = {
200
212
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
201
213
val tmpInput = ctx.freshName(" tmpInput" )
202
214
val tmpCursor = ctx.freshName(" tmpCursor" )
203
215
val previousCursor = ctx.freshName(" previousCursor" )
204
216
205
217
// Writes out unsafe map according to the format described in `UnsafeMapData`.
218
+ val keyArray = writeArrayToBuffer(
219
+ ctx, s " $tmpInput.keyArray() " , keyType, false , rowWriter)
220
+ val valueArray = writeArrayToBuffer(
221
+ ctx, s " $tmpInput.valueArray() " , valueType, valueContainsNull, rowWriter)
222
+
206
223
s """
207
224
|final MapData $tmpInput = $input;
208
225
|if ( $tmpInput instanceof UnsafeMapData) {
@@ -219,15 +236,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
219
236
| // Remember the current cursor so that we can write numBytes of key array later.
220
237
| final int $tmpCursor = $rowWriter.cursor();
221
238
|
222
- | ${writeArrayToBuffer(ctx, s " $tmpInput . keyArray() " , keyType, rowWriter)}
239
+ | $keyArray
223
240
|
224
241
| // Write the numBytes of key array into the first 8 bytes.
225
242
| Platform.putLong(
226
243
| $rowWriter.getBuffer(),
227
244
| $tmpCursor - 8,
228
245
| $rowWriter.cursor() - $tmpCursor);
229
246
|
230
- | ${writeArrayToBuffer(ctx, s " $tmpInput . valueArray() " , valueType, rowWriter)}
247
+ | $valueArray
231
248
| $rowWriter.setOffsetAndSizeFromPreviousCursor( $index, $previousCursor);
232
249
|}
233
250
""" .stripMargin
@@ -240,20 +257,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
240
257
dt : DataType ,
241
258
writer : String ): String = dt match {
242
259
case t : StructType =>
243
- writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer)
260
+ writeStructToBuffer(
261
+ ctx, input, index, t.map(e => Schema (e.dataType, e.nullable)), writer)
244
262
245
- case ArrayType (et, _ ) =>
263
+ case ArrayType (et, en ) =>
246
264
val previousCursor = ctx.freshName(" previousCursor" )
247
265
s """
248
266
|// Remember the current cursor so that we can calculate how many bytes are
249
267
|// written later.
250
268
|final int $previousCursor = $writer.cursor();
251
- | ${writeArrayToBuffer(ctx, input, et, writer)}
269
+ | ${writeArrayToBuffer(ctx, input, et, en, writer)}
252
270
| $writer.setOffsetAndSizeFromPreviousCursor( $index, $previousCursor);
253
271
""" .stripMargin
254
272
255
- case MapType (kt, vt, _ ) =>
256
- writeMapToBuffer(ctx, input, index, kt, vt, writer)
273
+ case MapType (kt, vt, vn ) =>
274
+ writeMapToBuffer(ctx, input, index, kt, vt, vn, writer)
257
275
258
276
case DecimalType .Fixed (precision, scale) =>
259
277
s " $writer.write( $index, $input, $precision, $scale); "
@@ -268,12 +286,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
268
286
expressions : Seq [Expression ],
269
287
useSubexprElimination : Boolean = false ): ExprCode = {
270
288
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
271
- val exprTypes = expressions.map(_ .dataType)
289
+ val exprSchemas = expressions.map(e => Schema (e .dataType, e.nullable) )
272
290
273
- val numVarLenFields = exprTypes .count {
274
- case dt if UnsafeRow .isFixedLength(dt) => false
291
+ val numVarLenFields = exprSchemas .count {
292
+ case Schema (dt, _) => ! UnsafeRow .isFixedLength(dt)
275
293
// TODO: consider large decimal and interval type
276
- case _ => true
277
294
}
278
295
279
296
val rowWriterClass = classOf [UnsafeRowWriter ].getName
@@ -284,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
284
301
val evalSubexpr = ctx.subexprFunctions.mkString(" \n " )
285
302
286
303
val writeExpressions = writeExpressionsToBuffer(
287
- ctx, ctx.INPUT_ROW , exprEvals, exprTypes , rowWriter, isTopLevel = true )
304
+ ctx, ctx.INPUT_ROW , exprEvals, exprSchemas , rowWriter, isTopLevel = true )
288
305
289
306
val code =
290
307
code """
0 commit comments