Skip to content

Commit 128865c

Browse files
committed
add tests for nested struct
1 parent 8fbda34 commit 128865c

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri
3333
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
3434
import org.apache.commons.io.FileUtils
3535

36-
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkThrowable, SparkUpgradeException}
36+
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkRuntimeException, SparkThrowable, SparkUpgradeException}
3737
import org.apache.spark.TestUtils.assertExceptionMsg
3838
import org.apache.spark.sql._
3939
import org.apache.spark.sql.TestingUDT.IntervalData
@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone
4545
import org.apache.spark.sql.execution.{FormattedMode, SparkPlan}
4646
import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, FilePartition}
4747
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
48-
import org.apache.spark.sql.functions.col
48+
import org.apache.spark.sql.functions._
4949
import org.apache.spark.sql.internal.LegacyBehaviorPolicy
5050
import org.apache.spark.sql.internal.LegacyBehaviorPolicy._
5151
import org.apache.spark.sql.internal.SQLConf
@@ -1345,6 +1345,49 @@ abstract class AvroSuite
13451345
}
13461346
}
13471347

1348+
test("to_avro nested struct schema nullability mismatch") {
  // Exercise both mismatch positions: a nullable leaf field and a nullable inner struct,
  // each written against an Avro schema that declares everything non-nullable.
  for ((innerIsNullable, outerIsNullable) <- Seq((true, false), (false, true))) {
    val leafSchema = StructType(Seq(StructField("field1", IntegerType, innerIsNullable)))
    val wrapperSchema = StructType(Seq(StructField("innerStruct", leafSchema, outerIsNullable)))
    val topLevelSchema = StructType(Seq(StructField("outerStruct", wrapperSchema, false)))

    // Place the null at whichever level the Catalyst schema marks nullable.
    val nullBearingRow = if (innerIsNullable) Row(Row(null)) else Row(null)
    val rows = Seq(Row(Row(Row(1))), Row(nullBearingRow), Row(Row(Row(3))))
    val frame = spark.createDataFrame(spark.sparkContext.parallelize(rows), topLevelSchema)

    val avroTypeStruct = """{
      | "type": "record",
      | "name": "outerStruct",
      | "fields": [
      | {
      | "name": "innerStruct",
      | "type": {
      | "type": "record",
      | "name": "innerStruct",
      | "fields": [
      | {"name": "field1", "type": "int"}
      | ]
      | }
      | }
      | ]
      |}
      """.stripMargin // nullability mismatch for innerStruct

    // The reported field and its SQL type depend on which level held the null.
    val (expectedName, expectedSchemaText) =
      if (outerIsNullable) ("`innerStruct`", "\"STRUCT<field1: INT NOT NULL>\"")
      else ("`field1`", "\"INT\"")

    val thrown = intercept[SparkRuntimeException] {
      frame.select(to_avro($"outerStruct", avroTypeStruct)).collect()
    }
    checkError(
      exception = thrown,
      condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
      parameters = Map(
        "name" -> expectedName,
        "schema" -> expectedSchemaText))
  }
}
1390+
13481391
test("support user provided avro schema for writing nullable fixed type") {
13491392
withTempPath { tempDir =>
13501393
val avroSchema =

sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,18 +285,19 @@ private[sql] class AvroSerializer(
285285

286286
val numFields = catalystStruct.length
287287
val avroFields = avroStruct.getFields()
288+
val isSchemaNullable = avroFields.asScala.map(_.schema().isNullable)
288289
row: InternalRow =>
289290
val result = new Record(avroStruct)
290291
var i = 0
291292
while (i < numFields) {
292293
if (row.isNullAt(i)) {
293-
val avroField = avroFields.get(i)
294-
if (!avroField.schema().isNullable) {
294+
if (!isSchemaNullable(i)) {
295295
throw new SparkRuntimeException(
296296
errorClass = "AVRO_CANNOT_WRITE_NULL_FIELD",
297297
messageParameters = Map(
298-
"name" -> toSQLId(avroField.name),
299-
"schema" -> toSQLType(SchemaConverters.toSqlType(avroField.schema()).dataType)))
298+
"name" -> toSQLId(avroFields.get(i).name),
299+
"schema" -> toSQLType(SchemaConverters.toSqlType(
300+
avroFields.get(i).schema()).dataType)))
300301
}
301302
result.put(avroIndices(i), null)
302303
} else {

0 commit comments

Comments
 (0)