@@ -36,7 +36,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
36
36
case NullType => true
37
37
case t : AtomicType => true
38
38
case _ : CalendarIntervalType => true
39
- case t : StructType => t.toSeq. forall(field => canSupport(field.dataType))
39
+ case t : StructType => t.forall(field => canSupport(field.dataType))
40
40
case t : ArrayType if canSupport(t.elementType) => true
41
41
case MapType (kt, vt, _) if canSupport(kt) && canSupport(vt) => true
42
42
case udt : UserDefinedType [_] => canSupport(udt.sqlType)
@@ -49,25 +49,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
49
49
input : String ,
50
50
fieldTypes : Seq [DataType ],
51
51
bufferHolder : String ): String = {
52
+ // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
53
+ val tmpInput = ctx.freshName(" tmpInput" )
52
54
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
53
- val javaType = ctx.javaType(dt)
54
- val isNullVar = ctx.freshName(" isNull" )
55
- val valueVar = ctx.freshName(" value" )
56
- val defaultValue = ctx.defaultValue(dt)
57
- val readValue = ctx.getValue(input, dt, i.toString)
58
- val code =
59
- s """
60
- boolean $isNullVar = $input.isNullAt( $i);
61
- $javaType $valueVar = $isNullVar ? $defaultValue : $readValue;
62
- """
63
- ExprCode (code, isNullVar, valueVar)
55
+ ExprCode (" " , s " $tmpInput.isNullAt( $i) " , ctx.getValue(tmpInput, dt, i.toString))
64
56
}
65
57
66
58
s """
67
- if ( $input instanceof UnsafeRow) {
68
- ${writeUnsafeData(ctx, s " ((UnsafeRow) $input) " , bufferHolder)}
59
+ final InternalRow $tmpInput = $input;
60
+ if ( $tmpInput instanceof UnsafeRow) {
61
+ ${writeUnsafeData(ctx, s " ((UnsafeRow) $tmpInput) " , bufferHolder)}
69
62
} else {
70
- ${writeExpressionsToBuffer(ctx, input , fieldEvals, fieldTypes, bufferHolder)}
63
+ ${writeExpressionsToBuffer(ctx, tmpInput , fieldEvals, fieldTypes, bufferHolder)}
71
64
}
72
65
"""
73
66
}
@@ -167,9 +160,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
167
160
}
168
161
}
169
162
163
+ val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != null )) {
164
+ // TODO: support whole stage codegen
165
+ writeFields.mkString(" \n " )
166
+ } else {
167
+ assert(row != null , " the input row name cannot be null when generating code to write it." )
168
+ ctx.splitExpressions(
169
+ expressions = writeFields,
170
+ funcName = " writeFields" ,
171
+ arguments = Seq (" InternalRow" -> row))
172
+ }
173
+
170
174
s """
171
175
$resetWriter
172
- ${ctx.splitExpressions(row, writeFields)}
176
+ $writeFieldsCode
173
177
""" .trim
174
178
}
175
179
@@ -179,13 +183,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
179
183
input : String ,
180
184
elementType : DataType ,
181
185
bufferHolder : String ): String = {
186
+ // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
187
+ val tmpInput = ctx.freshName(" tmpInput" )
182
188
val arrayWriterClass = classOf [UnsafeArrayWriter ].getName
183
189
val arrayWriter = ctx.freshName(" arrayWriter" )
184
190
ctx.addMutableState(arrayWriterClass, arrayWriter,
185
191
s " $arrayWriter = new $arrayWriterClass(); " )
186
192
val numElements = ctx.freshName(" numElements" )
187
193
val index = ctx.freshName(" index" )
188
- val element = ctx.freshName(" element" )
189
194
190
195
val et = elementType match {
191
196
case udt : UserDefinedType [_] => udt.sqlType
@@ -201,6 +206,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
201
206
}
202
207
203
208
val tmpCursor = ctx.freshName(" tmpCursor" )
209
+ val element = ctx.getValue(tmpInput, et, index)
204
210
val writeElement = et match {
205
211
case t : StructType =>
206
212
s """
@@ -233,17 +239,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
233
239
234
240
val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else " "
235
241
s """
236
- if ( $input instanceof UnsafeArrayData) {
237
- ${writeUnsafeData(ctx, s " ((UnsafeArrayData) $input) " , bufferHolder)}
242
+ final ArrayData $tmpInput = $input;
243
+ if ( $tmpInput instanceof UnsafeArrayData) {
244
+ ${writeUnsafeData(ctx, s " ((UnsafeArrayData) $tmpInput) " , bufferHolder)}
238
245
} else {
239
- final int $numElements = $input .numElements();
246
+ final int $numElements = $tmpInput .numElements();
240
247
$arrayWriter.initialize( $bufferHolder, $numElements, $elementOrOffsetSize);
241
248
242
249
for (int $index = 0; $index < $numElements; $index++) {
243
- if ( $input .isNullAt( $index)) {
250
+ if ( $tmpInput .isNullAt( $index)) {
244
251
$arrayWriter.setNull $primitiveTypeName( $index);
245
252
} else {
246
- final $jt $element = ${ctx.getValue(input, et, index)};
247
253
$writeElement
248
254
}
249
255
}
@@ -258,31 +264,28 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
258
264
keyType : DataType ,
259
265
valueType : DataType ,
260
266
bufferHolder : String ): String = {
261
- val keys = ctx.freshName( " keys " )
262
- val values = ctx.freshName(" values " )
267
+ // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
268
+ val tmpInput = ctx.freshName(" tmpInput " )
263
269
val tmpCursor = ctx.freshName(" tmpCursor" )
264
270
265
-
266
271
// Writes out unsafe map according to the format described in `UnsafeMapData`.
267
272
s """
268
- if ( $input instanceof UnsafeMapData) {
269
- ${writeUnsafeData(ctx, s " ((UnsafeMapData) $input) " , bufferHolder)}
273
+ final MapData $tmpInput = $input;
274
+ if ( $tmpInput instanceof UnsafeMapData) {
275
+ ${writeUnsafeData(ctx, s " ((UnsafeMapData) $tmpInput) " , bufferHolder)}
270
276
} else {
271
- final ArrayData $keys = $input.keyArray();
272
- final ArrayData $values = $input.valueArray();
273
-
274
277
// preserve 8 bytes to write the key array numBytes later.
275
278
$bufferHolder.grow(8);
276
279
$bufferHolder.cursor += 8;
277
280
278
281
// Remember the current cursor so that we can write numBytes of key array later.
279
282
final int $tmpCursor = $bufferHolder.cursor;
280
283
281
- ${writeArrayToBuffer(ctx, keys , keyType, bufferHolder)}
284
+ ${writeArrayToBuffer(ctx, s " $tmpInput .keyArray() " , keyType, bufferHolder)}
282
285
// Write the numBytes of key array into the first 8 bytes.
283
286
Platform.putLong( $bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);
284
287
285
- ${writeArrayToBuffer(ctx, values , valueType, bufferHolder)}
288
+ ${writeArrayToBuffer(ctx, s " $tmpInput .valueArray() " , valueType, bufferHolder)}
286
289
}
287
290
"""
288
291
}
0 commit comments