Skip to content

Commit d66667f

Browse files
fanyue-xia
authored and committed
[SPARK-50906][SQL][SS] Add nullability check for if inputs of to_avro align with schema
### What changes were proposed in this pull request?

Previously, we did not explicitly check the case where the input of `to_avro` is `null` but the schema does not allow `null`. As a result, an NPE would be raised in this situation. This PR adds the check during serialization, before writing to Avro, and raises a user-facing error if the above occurs.

### Why are the changes needed?

It makes it easier for the user to understand and face the error.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #49590 from fanyue-xia/to_avro_improve_NPE.

Lead-authored-by: fanyue-xia <chloexfy@gmail.com>
Co-authored-by: fanyue-xia <chloe.xia@databircks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
1 parent 8dbf1dd commit d66667f

File tree

3 files changed

+89
-7
lines changed

3 files changed

+89
-7
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@
117117
},
118118
"sqlState" : "42604"
119119
},
120+
"AVRO_CANNOT_WRITE_NULL_FIELD" : {
121+
"message" : [
122+
"Cannot write null value for field <name> defined as non-null Avro data type <dataType>.",
123+
"To allow null value for this field, specify its avro schema as a union type with \"null\" using `avroSchema` option."
124+
],
125+
"sqlState" : "22004"
126+
},
120127
"AVRO_INCOMPATIBLE_READ_TYPE" : {
121128
"message" : [
122129
"Cannot convert Avro <avroPath> to SQL <sqlPath> because the original encoded data type is <avroType>, however you're trying to read the field as <sqlType>, which would lead to an incorrect answer.",

connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ import java.util.UUID
2525

2626
import scala.jdk.CollectionConverters._
2727

28-
import org.apache.avro.{AvroTypeException, Schema, SchemaBuilder, SchemaFormatter}
28+
import org.apache.avro.{Schema, SchemaBuilder, SchemaFormatter}
2929
import org.apache.avro.Schema.{Field, Type}
3030
import org.apache.avro.Schema.Type._
3131
import org.apache.avro.file.{DataFileReader, DataFileWriter}
3232
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
3333
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
3434
import org.apache.commons.io.FileUtils
3535

36-
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException}
36+
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkRuntimeException, SparkThrowable, SparkUpgradeException}
3737
import org.apache.spark.TestUtils.assertExceptionMsg
3838
import org.apache.spark.sql._
3939
import org.apache.spark.sql.TestingUDT.IntervalData
@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone
4545
import org.apache.spark.sql.execution.{FormattedMode, SparkPlan}
4646
import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, FilePartition}
4747
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
48-
import org.apache.spark.sql.functions.col
48+
import org.apache.spark.sql.functions._
4949
import org.apache.spark.sql.internal.LegacyBehaviorPolicy
5050
import org.apache.spark.sql.internal.LegacyBehaviorPolicy._
5151
import org.apache.spark.sql.internal.SQLConf
@@ -100,6 +100,14 @@ abstract class AvroSuite
100100
SchemaFormatter.format(AvroUtils.JSON_INLINE_FORMAT, schema)
101101
}
102102

103+
/**
 * Walks the `getCause` chain of `ex` and returns the deepest throwable,
 * i.e. the first one in the chain whose own cause is null. Returns `ex`
 * itself when it has no cause.
 */
@scala.annotation.tailrec
private def getRootCause(ex: Throwable): Throwable =
  ex.getCause match {
    case null => ex
    case cause => getRootCause(cause)
  }
110+
103111
// Check whether an Avro schema of union type is converted to SQL in an expected way, when the
104112
// stable ID option is on.
105113
//
@@ -1317,7 +1325,16 @@ abstract class AvroSuite
13171325
dfWithNull.write.format("avro")
13181326
.option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}")
13191327
}
1320-
assertExceptionMsg[AvroTypeException](e1, "value null is not a SuitEnumType")
1328+
1329+
val expectedDatatype = "{\"type\":\"enum\",\"name\":\"SuitEnumType\"," +
1330+
"\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"
1331+
1332+
checkError(
1333+
getRootCause(e1).asInstanceOf[SparkThrowable],
1334+
condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
1335+
parameters = Map(
1336+
"name" -> "`Suit`",
1337+
"dataType" -> expectedDatatype))
13211338

13221339
// Writing df containing data not in the enum will throw an exception
13231340
val e2 = intercept[SparkException] {
@@ -1332,6 +1349,50 @@ abstract class AvroSuite
13321349
}
13331350
}
13341351

1352+
// Verifies that to_avro raises the user-facing AVRO_CANNOT_WRITE_NULL_FIELD
// error (instead of an NPE) when a field that is nullable on the Catalyst side
// holds null but the corresponding Avro schema field forbids null.
test("to_avro nested struct schema nullability mismatch") {
  for ((innerIsNullable, outerIsNullable) <- Seq((true, false), (false, true))) {
    // Three-level struct where exactly one level is nullable on the SQL side.
    val innerSchema = StructType(Seq(StructField("field1", IntegerType, innerIsNullable)))
    val outerSchema = StructType(Seq(StructField("innerStruct", innerSchema, outerIsNullable)))
    val nestedSchema = StructType(Seq(StructField("outerStruct", outerSchema, false)))

    // Put the null at whichever level the SQL schema permits it.
    val nullBearingRow = if (innerIsNullable) Row(Row(null)) else Row(null)
    val rows = Seq(Row(Row(Row(1))), Row(nullBearingRow), Row(Row(Row(3))))
    val frame = spark.createDataFrame(spark.sparkContext.parallelize(rows), nestedSchema)

    // Avro schema with no "null" union branches: both innerStruct and field1
    // are non-nullable, so it mismatches the nullable Catalyst field above.
    val avroTypeStruct = s"""{
      | "type": "record",
      | "name": "outerStruct",
      | "fields": [
      | {
      | "name": "innerStruct",
      | "type": {
      | "type": "record",
      | "name": "innerStruct",
      | "fields": [
      | {"name": "field1", "type": "int"}
      | ]
      | }
      | }
      | ]
      |}
      """.stripMargin

    // The error should name the shallowest field whose value is null but
    // whose Avro type is non-nullable, together with that Avro type.
    val (expectedName, expectedSchema) =
      if (outerIsNullable) {
        ("`innerStruct`",
          "{\"type\":\"record\",\"name\":\"innerStruct\"" +
            ",\"fields\":[{\"name\":\"field1\",\"type\":\"int\"}]}")
      } else {
        ("`field1`", "\"int\"")
      }

    val ex = intercept[SparkRuntimeException] {
      frame.select(avro.functions.to_avro($"outerStruct", avroTypeStruct)).collect()
    }
    checkError(
      exception = ex,
      condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
      parameters = Map(
        "name" -> expectedName,
        "dataType" -> expectedSchema))
  }
}
1395+
13351396
test("support user provided avro schema for writing nullable fixed type") {
13361397
withTempPath { tempDir =>
13371398
val avroSchema =
@@ -1517,9 +1578,12 @@ abstract class AvroSuite
15171578
.save(s"$tempDir/${UUID.randomUUID()}")
15181579
}
15191580
assert(ex.getCondition == "TASK_WRITE_FAILED")
1520-
assert(ex.getCause.isInstanceOf[java.lang.NullPointerException])
1521-
assert(ex.getCause.getMessage.contains(
1522-
"null value for (non-nullable) string at test_schema.Name"))
1581+
checkError(
1582+
ex.getCause.asInstanceOf[SparkThrowable],
1583+
condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
1584+
parameters = Map(
1585+
"name" -> "`Name`",
1586+
"dataType" -> "\"string\""))
15231587
}
15241588
}
15251589

sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@ import org.apache.avro.Schema.Type._
2929
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
3030
import org.apache.avro.util.Utf8
3131

32+
import org.apache.spark.SparkRuntimeException
3233
import org.apache.spark.internal.Logging
3334
import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField}
3435
import org.apache.spark.sql.catalyst.InternalRow
3536
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
3637
import org.apache.spark.sql.catalyst.util.DateTimeUtils
38+
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
3739
import org.apache.spark.sql.execution.datasources.DataSourceUtils
3840
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
3941
import org.apache.spark.sql.types._
@@ -282,11 +284,20 @@ private[sql] class AvroSerializer(
282284
}.toArray.unzip
283285

284286
val numFields = catalystStruct.length
287+
val avroFields = avroStruct.getFields()
288+
val isSchemaNullable = avroFields.asScala.map(_.schema().isNullable)
285289
row: InternalRow =>
286290
val result = new Record(avroStruct)
287291
var i = 0
288292
while (i < numFields) {
289293
if (row.isNullAt(i)) {
294+
if (!isSchemaNullable(i)) {
295+
throw new SparkRuntimeException(
296+
errorClass = "AVRO_CANNOT_WRITE_NULL_FIELD",
297+
messageParameters = Map(
298+
"name" -> toSQLId(avroFields.get(i).name),
299+
"dataType" -> avroFields.get(i).schema().toString))
300+
}
290301
result.put(avroIndices(i), null)
291302
} else {
292303
result.put(avroIndices(i), fieldConverters(i).apply(row, i))

0 commit comments

Comments
 (0)