Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@
},
"sqlState" : "42604"
},
"AVRO_CANNOT_WRITE_NULL_FIELD" : {
"message" : [
"The record to be written to Avro contains null value for non-null field <name> with schema <schema> from Avro schema.",
"To allow writing this field, explicitly specifying the avroSchema as option."
],
"sqlState" : "22004"
},
"AVRO_INCOMPATIBLE_READ_TYPE" : {
"message" : [
"Cannot convert Avro <avroPath> to SQL <sqlPath> because the original encoded data type is <avroType>, however you're trying to read the field as <sqlType>, which would lead to an incorrect answer.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ import java.util.UUID

import scala.jdk.CollectionConverters._

import org.apache.avro.{AvroTypeException, Schema, SchemaBuilder, SchemaFormatter}
import org.apache.avro.{Schema, SchemaBuilder, SchemaFormatter}
import org.apache.avro.Schema.{Field, Type}
import org.apache.avro.Schema.Type._
import org.apache.avro.file.{DataFileReader, DataFileWriter}
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
import org.apache.commons.io.FileUtils

import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException}
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkRuntimeException, SparkThrowable, SparkUpgradeException}
import org.apache.spark.TestUtils.assertExceptionMsg
import org.apache.spark.sql._
import org.apache.spark.sql.TestingUDT.IntervalData
Expand All @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone
import org.apache.spark.sql.execution.{FormattedMode, SparkPlan}
import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, FilePartition}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.LegacyBehaviorPolicy
import org.apache.spark.sql.internal.LegacyBehaviorPolicy._
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -100,6 +100,14 @@ abstract class AvroSuite
SchemaFormatter.format(AvroUtils.JSON_INLINE_FORMAT, schema)
}

/**
 * Walks the cause chain of `ex` and returns the innermost (root) cause,
 * or `ex` itself when it has no cause.
 *
 * NOTE(review): assumes the cause chain is acyclic (true for JVM-created
 * exceptions in practice); a cyclic chain would loop forever, same as the
 * original implementation.
 */
@scala.annotation.tailrec
private def getRootCause(ex: Throwable): Throwable = {
  val cause = ex.getCause
  if (cause == null) ex else getRootCause(cause)
}

// Check whether an Avro schema of union type is converted to SQL in an expected way, when the
// stable ID option is on.
//
Expand Down Expand Up @@ -1317,7 +1325,13 @@ abstract class AvroSuite
dfWithNull.write.format("avro")
.option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}")
}
assertExceptionMsg[AvroTypeException](e1, "value null is not a SuitEnumType")

checkError(
getRootCause(e1).asInstanceOf[SparkThrowable],
condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
parameters = Map(
"name" -> "`Suit`",
"schema" -> "\"STRING\""))

// Writing df containing data not in the enum will throw an exception
val e2 = intercept[SparkException] {
Expand All @@ -1332,6 +1346,49 @@ abstract class AvroSuite
}
}

test("to_avro nested struct schema nullability mismatch") {
Seq((true, false), (false, true)).foreach {
case (innerNull, outerNull) =>
val innerSchema = StructType(Seq(StructField("field1", IntegerType, innerNull)))
val outerSchema = StructType(Seq(StructField("innerStruct", innerSchema, outerNull)))
val nestedSchema = StructType(Seq(StructField("outerStruct", outerSchema, false)))

val rowWithNull = if (innerNull) Row(Row(null)) else Row(null)
val data = Seq(Row(Row(Row(1))), Row(rowWithNull), Row(Row(Row(3))))
val df = spark.createDataFrame(spark.sparkContext.parallelize(data), nestedSchema)

val avroTypeStruct = s"""{
| "type": "record",
| "name": "outerStruct",
| "fields": [
| {
| "name": "innerStruct",
| "type": {
| "type": "record",
| "name": "innerStruct",
| "fields": [
| {"name": "field1", "type": "int"}
| ]
| }
| }
| ]
|}
""".stripMargin // nullability mismatch for innerStruct

val expectedErrorName = if (outerNull) "`innerStruct`" else "`field1`"
val expectedErrorSchema = if (outerNull) "\"STRUCT<field1: INT NOT NULL>\"" else "\"INT\""

checkError(
exception = intercept[SparkRuntimeException] {
df.select(avro.functions.to_avro($"outerStruct", avroTypeStruct)).collect()
},
condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
parameters = Map(
"name" -> expectedErrorName,
"schema" -> expectedErrorSchema))
}
}

test("support user provided avro schema for writing nullable fixed type") {
withTempPath { tempDir =>
val avroSchema =
Expand Down Expand Up @@ -1517,9 +1574,12 @@ abstract class AvroSuite
.save(s"$tempDir/${UUID.randomUUID()}")
}
assert(ex.getCondition == "TASK_WRITE_FAILED")
assert(ex.getCause.isInstanceOf[java.lang.NullPointerException])
assert(ex.getCause.getMessage.contains(
"null value for (non-nullable) string at test_schema.Name"))
checkError(
ex.getCause.asInstanceOf[SparkThrowable],
condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
parameters = Map(
"name" -> "`Name`",
"schema" -> "\"STRING\""))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ import org.apache.avro.Schema.Type._
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
import org.apache.avro.util.Utf8

import org.apache.spark.SparkRuntimeException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -282,11 +284,21 @@ private[sql] class AvroSerializer(
}.toArray.unzip

val numFields = catalystStruct.length
val avroFields = avroStruct.getFields()
val isSchemaNullable = avroFields.asScala.map(_.schema().isNullable)
row: InternalRow =>
val result = new Record(avroStruct)
var i = 0
while (i < numFields) {
if (row.isNullAt(i)) {
if (!isSchemaNullable(i)) {
throw new SparkRuntimeException(
errorClass = "AVRO_CANNOT_WRITE_NULL_FIELD",
messageParameters = Map(
"name" -> toSQLId(avroFields.get(i).name),
"schema" -> toSQLType(SchemaConverters.toSqlType(
avroFields.get(i).schema()).dataType)))
}
result.put(avroIndices(i), null)
} else {
result.put(avroIndices(i), fieldConverters(i).apply(row, i))
Expand Down
Loading