Skip to content

Commit cd92f25

Browse files
committed
[SPARK-25746][SQL][FOLLOWUP] do not add unnecessary If expression
## What changes were proposed in this pull request? This is a follow-up of apache#22749. When we construct the new deserializer in `ExpressionEncoder.tuple`, we don't need to add an `if(isnull ...)` check for each field. The field expressions are either simple expressions that propagate null correctly (e.g. `GetStructField(GetColumnByOrdinal(0, schema), index)`), or complex expressions that already have the isnull check. ## How was this patch tested? Existing tests. Closes apache#22898 from cloud-fan/minor. Authored-by: Wenchen Fan <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent cc82b9f commit cd92f25

File tree

1 file changed

+14
-22
lines changed

1 file changed

+14
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 14 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -89,59 +89,51 @@ object ExpressionEncoder {
8989
*/
9090
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
9191
// TODO: check if encoders length is more than 22 and throw exception for it.
92-
9392
encoders.foreach(_.assertUnresolved())
9493

95-
val schema = StructType(encoders.zipWithIndex.map {
96-
case (e, i) =>
97-
StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable)
98-
})
99-
10094
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
10195

96+
val newSerializerInput = BoundReference(0, ObjectType(cls), nullable = true)
10297
val serializers = encoders.zipWithIndex.map { case (enc, index) =>
10398
val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct
10499
assert(boundRefs.size == 1, "object serializer should have only one bound reference but " +
105100
s"there are ${boundRefs.size}")
106101

107102
val originalInputObject = boundRefs.head
108103
val newInputObject = Invoke(
109-
BoundReference(0, ObjectType(cls), nullable = true),
104+
newSerializerInput,
110105
s"_${index + 1}",
111106
originalInputObject.dataType,
112107
returnNullable = originalInputObject.nullable)
113108

114109
val newSerializer = enc.objSerializer.transformUp {
115-
case b: BoundReference => newInputObject
110+
case BoundReference(0, _, _) => newInputObject
116111
}
117112

118113
Alias(newSerializer, s"_${index + 1}")()
119114
}
115+
val newSerializer = CreateStruct(serializers)
120116

121-
val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
117+
val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
118+
val deserializers = encoders.zipWithIndex.map { case (enc, index) =>
122119
val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
123120
assert(getColExprs.size == 1, "object deserializer should have only one " +
124121
s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
125122

126-
val input = GetStructField(GetColumnByOrdinal(0, schema), index)
127-
val newDeserializer = enc.objDeserializer.transformUp {
123+
val input = GetStructField(newDeserializerInput, index)
124+
enc.objDeserializer.transformUp {
128125
case GetColumnByOrdinal(0, _) => input
129126
}
130-
if (schema(index).nullable) {
131-
If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer)
132-
} else {
133-
newDeserializer
134-
}
135127
}
128+
val newDeserializer = NewInstance(cls, deserializers, ObjectType(cls), propagateNull = false)
136129

137-
val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)),
138-
Literal.create(null, schema), CreateStruct(serializers))
139-
val deserializer =
140-
NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
130+
def nullSafe(input: Expression, result: Expression): Expression = {
131+
If(IsNull(input), Literal.create(null, result.dataType), result)
132+
}
141133

142134
new ExpressionEncoder[Any](
143-
serializer,
144-
deserializer,
135+
nullSafe(newSerializerInput, newSerializer),
136+
nullSafe(newDeserializerInput, newDeserializer),
145137
ClassTag(cls))
146138
}
147139

0 commit comments

Comments
 (0)