@@ -33,7 +33,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri
3333import org .apache .avro .generic .GenericData .{EnumSymbol , Fixed }
3434import org .apache .commons .io .FileUtils
3535
36- import org .apache .spark .{SPARK_VERSION_SHORT , SparkConf , SparkException , SparkThrowable , SparkUpgradeException }
36+ import org .apache .spark .{SPARK_VERSION_SHORT , SparkConf , SparkException , SparkRuntimeException , SparkThrowable , SparkUpgradeException }
3737import org .apache .spark .TestUtils .assertExceptionMsg
3838import org .apache .spark .sql ._
3939import org .apache .spark .sql .TestingUDT .IntervalData
@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone
4545import org .apache .spark .sql .execution .{FormattedMode , SparkPlan }
4646import org .apache .spark .sql .execution .datasources .{CommonFileDataSourceSuite , DataSource , FilePartition }
4747import org .apache .spark .sql .execution .datasources .v2 .BatchScanExec
48- import org .apache .spark .sql .functions .col
48+ import org .apache .spark .sql .functions ._
4949import org .apache .spark .sql .internal .LegacyBehaviorPolicy
5050import org .apache .spark .sql .internal .LegacyBehaviorPolicy ._
5151import org .apache .spark .sql .internal .SQLConf
@@ -1345,6 +1345,49 @@ abstract class AvroSuite
13451345 }
13461346 }
13471347
// Checks that to_avro fails with AVRO_CANNOT_WRITE_NULL_FIELD when the Catalyst schema
// allows a null that the user-provided Avro schema (plain "record" / "int", no union
// with "null") cannot represent, at either level of a nested struct.
test("to_avro nested struct schema nullability mismatch") {
  // Each pair is (field1 nullable, innerStruct nullable) on the Catalyst side; exactly
  // one level is nullable per case, so the null lands at a different depth each time.
  Seq((true, false), (false, true)).foreach { case (innerNull, outerNull) =>
    val innerSchema =
      StructType(Seq(StructField("field1", IntegerType, nullable = innerNull)))
    val outerSchema =
      StructType(Seq(StructField("innerStruct", innerSchema, nullable = outerNull)))
    val nestedSchema =
      StructType(Seq(StructField("outerStruct", outerSchema, nullable = false)))

    // Put the null where the Catalyst schema permits it: in field1 when the inner
    // field is nullable, otherwise in innerStruct itself.
    val rowWithNull = if (innerNull) Row(Row(null)) else Row(null)
    val data = Seq(Row(Row(Row(1))), Row(rowWithNull), Row(Row(Row(3))))
    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), nestedSchema)

    // Avro schema for outerStruct in which neither innerStruct nor field1 accepts null.
    // No interpolation is needed, so a plain triple-quoted string is used.
    val avroTypeStruct = """{
      |  "type": "record",
      |  "name": "outerStruct",
      |  "fields": [
      |    {
      |      "name": "innerStruct",
      |      "type": {
      |        "type": "record",
      |        "name": "innerStruct",
      |        "fields": [
      |          {"name": "field1", "type": "int"}
      |        ]
      |      }
      |    }
      |  ]
      |}
    """.stripMargin // nullability mismatch for innerStruct

    // The error is expected to name the field that actually holds the null value.
    val expectedErrorName = if (outerNull) "`innerStruct`" else "`field1`"
    val expectedErrorSchema =
      if (outerNull) "\"STRUCT<field1: INT NOT NULL>\"" else "\"INT\""

    checkError(
      exception = intercept[SparkRuntimeException] {
        df.select(to_avro($"outerStruct", avroTypeStruct)).collect()
      },
      condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
      parameters = Map(
        "name" -> expectedErrorName,
        "schema" -> expectedErrorSchema))
  }
}
1390+
13481391 test(" support user provided avro schema for writing nullable fixed type" ) {
13491392 withTempPath { tempDir =>
13501393 val avroSchema =
0 commit comments