
Commit 0a9172a

mgaido91 authored and cloud-fan committed
[SPARK-23835][SQL] Add not-null check to Tuples' arguments deserialization
## What changes were proposed in this pull request?

There was no nullability check on the arguments of `Tuple`s. This could lead to weird behavior when a null value had to be deserialized into a non-nullable Scala object: in those cases, the `null` was silently transformed into a valid value (like `-1` for `Int`), corresponding to the default value used in the SQL codebase. This situation was very likely to happen when deserializing to a Tuple of primitive Scala types (like Double, Int, ...).

The PR adds `AssertNotNull` to the arguments of tuples which have been asked to be converted to non-nullable types.

## How was this patch tested?

Added UT.

Author: Marco Gaido <[email protected]>

Closes apache#20976 from mgaido91/SPARK-23835.
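A minimal sketch of the behavior being fixed (assumes `spark.implicits._` is in scope; the data is illustrative):

```scala
// A tuple column holding a null value.
val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS()

// Before SPARK-23835: the null in _2 was silently replaced by the SQL-codebase
// default for Int, yielding (1, -1).
// After SPARK-23835: the deserializer wraps the non-nullable argument in
// AssertNotNull, so this throws a NullPointerException instead.
ds.as[(Int, Int)].collect()
```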
1 parent 30ffb53 commit 0a9172a

5 files changed (+27, -12 lines)

external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -79,7 +79,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
       .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)

     try {
@@ -119,7 +119,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
       .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)

     try {
@@ -167,7 +167,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
       .selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)

     try {

external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
       .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)

     try {
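The likely reason for these test changes in both Kafka sink suites (an assumption from the diff, not stated in the commit message): the key column of the records read back in these tests can be null, so after this patch it can no longer be deserialized into a bare `Int`. Reading it as `Option[Int]` keeps the null representable:

```scala
// Key may be null in these tests; value is always set.
val reader = createKafkaReader(topic)
  .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
  .as[(Option[Int], Int)]  // Option absorbs null keys instead of throwing
  .map(_._2)               // the assertions only compare the values anyway
```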

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 7 additions & 7 deletions

@@ -382,22 +382,22 @@ object ScalaReflection extends ScalaReflection {
           val clsName = getClassNameFromType(fieldType)
           val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           // For tuples, we grab the inner fields by ordinal instead of by name.
-          if (cls.getName startsWith "scala.Tuple") {
+          val constructor = if (cls.getName startsWith "scala.Tuple") {
             deserializerFor(
               fieldType,
               Some(addToPathOrdinal(i, dataType, newTypePath)),
               newTypePath)
           } else {
-            val constructor = deserializerFor(
+            deserializerFor(
               fieldType,
               Some(addToPath(fieldName, dataType, newTypePath)),
               newTypePath)
+          }

-            if (!nullable) {
-              AssertNotNull(constructor, newTypePath)
-            } else {
-              constructor
-            }
+          if (!nullable) {
+            AssertNotNull(constructor, newTypePath)
+          } else {
+            constructor
           }
         }
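The shape of the fix: previously the null check lived only in the non-tuple branch, so tuple arguments escaped it. Now the `if`/`else` just yields the constructor expression and the check is applied uniformly afterwards. A self-contained sketch of that control flow (the types and names are illustrative, not Spark's actual internals):

```scala
// Minimal stand-ins for Catalyst expressions.
sealed trait Expr
case class Deserialize(path: String) extends Expr
case class AssertNotNull(child: Expr, typePath: Seq[String]) extends Expr

def fieldDeserializer(
    isTuple: Boolean,
    nullable: Boolean,
    ordinalPath: String,
    namePath: String,
    typePath: Seq[String]): Expr = {
  // Both branches now just produce the constructor expression...
  val constructor =
    if (isTuple) Deserialize(ordinalPath) // tuples: access fields by ordinal
    else Deserialize(namePath)            // other classes: access fields by name
  // ...and the null check is added for non-nullable targets in either case.
  if (!nullable) AssertNotNull(constructor, typePath) else constructor
}
```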

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 11 additions & 1 deletion

@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -365,4 +365,14 @@ class ScalaReflectionSuite extends SparkFunSuite {
         StructField("_2", NullType, nullable = true))),
       nullable = true))
   }
+
+  test("SPARK-23835: add null check to non-nullable types in Tuples") {
+    def numberOfCheckedArguments(deserializer: Expression): Int = {
+      assert(deserializer.isInstanceOf[NewInstance])
+      deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull])
+    }
+    assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2)
+    assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
+    assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
+  }
 }
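The expected counts follow from the usual nullability mapping: Scala primitives (`Double`, `Int`) map to non-nullable Catalyst types, boxed Java types (`java.lang.Double`, `java.lang.Integer`) stay nullable, and only non-nullable arguments get wrapped in `AssertNotNull`. A tiny sketch of that rule in plain Scala (not Spark code):

```scala
// Non-nullable iff the JVM type is a primitive value type.
def needsNullCheck(cls: Class[_]): Boolean = cls.isPrimitive

assert(needsNullCheck(classOf[Double]))            // primitive -> checked
assert(!needsNullCheck(classOf[java.lang.Double])) // boxed     -> unchecked
```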

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 5 additions & 0 deletions

@@ -1453,6 +1453,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val group2 = cached.groupBy("x").agg(min(col("z")) as "value")
     checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil)
   }
+
+  test("SPARK-23835: null primitive data type should throw NullPointerException") {
+    val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS()
+    intercept[NullPointerException](ds.as[(Int, Int)].collect())
+  }
 }

 case class TestDataUnion(x: Int, y: Int, z: Int)
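One detail this test relies on: the added `AssertNotNull` runs when the deserializer executes, not when the encoder is resolved, so the `NullPointerException` only surfaces once the Dataset is materialized — hence the `collect()` inside `intercept`. A sketch, continuing from the `ds` defined in the test body above:

```scala
val typed = ds.as[(Int, Int)]  // fine: only resolves the encoder at analysis time
typed.collect()                // the AssertNotNull fires here, at execution time
```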
