Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@
},
"sqlState" : "42604"
},
"AVRO_CANNOT_WRITE_NULL_FIELD" : {
"message" : [
"Cannot write null value for field <name> defined as non-null Avro data type <dataType>.",
"To allow null values for this field, specify its Avro schema as a union type with \"null\" using the `avroSchema` option."
],
"sqlState" : "22004"
},
"AVRO_INCOMPATIBLE_READ_TYPE" : {
"message" : [
"Cannot convert Avro <avroPath> to SQL <sqlPath> because the original encoded data type is <avroType>, however you're trying to read the field as <sqlType>, which would lead to an incorrect answer.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ import java.util.UUID

import scala.jdk.CollectionConverters._

import org.apache.avro.{AvroTypeException, Schema, SchemaBuilder, SchemaFormatter}
import org.apache.avro.{Schema, SchemaBuilder, SchemaFormatter}
import org.apache.avro.Schema.{Field, Type}
import org.apache.avro.Schema.Type._
import org.apache.avro.file.{DataFileReader, DataFileWriter}
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
import org.apache.commons.io.FileUtils

import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException}
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkRuntimeException, SparkThrowable, SparkUpgradeException}
import org.apache.spark.TestUtils.assertExceptionMsg
import org.apache.spark.sql._
import org.apache.spark.sql.TestingUDT.IntervalData
Expand All @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone
import org.apache.spark.sql.execution.{FormattedMode, SparkPlan}
import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, FilePartition}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.LegacyBehaviorPolicy
import org.apache.spark.sql.internal.LegacyBehaviorPolicy._
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -100,6 +100,14 @@ abstract class AvroSuite
SchemaFormatter.format(AvroUtils.JSON_INLINE_FORMAT, schema)
}

/**
 * Walks the cause chain of `ex` and returns the deepest throwable,
 * i.e. the first one whose `getCause` is null.
 */
private def getRootCause(ex: Throwable): Throwable = {
  @annotation.tailrec
  def descend(current: Throwable): Throwable = current.getCause match {
    case null => current
    case cause => descend(cause)
  }
  descend(ex)
}

// Check whether an Avro schema of union type is converted to SQL in an expected way, when the
// stable ID option is on.
//
Expand Down Expand Up @@ -1317,7 +1325,16 @@ abstract class AvroSuite
dfWithNull.write.format("avro")
.option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}")
}
assertExceptionMsg[AvroTypeException](e1, "value null is not a SuitEnumType")

val expectedDatatype = "{\"type\":\"enum\",\"name\":\"SuitEnumType\"," +
"\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"

checkError(
getRootCause(e1).asInstanceOf[SparkThrowable],
condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
parameters = Map(
"name" -> "`Suit`",
"dataType" -> expectedDatatype))

// Writing df containing data not in the enum will throw an exception
val e2 = intercept[SparkException] {
Expand All @@ -1332,6 +1349,50 @@ abstract class AvroSuite
}
}

test("to_avro nested struct schema nullability mismatch") {
  // Avro schema under test: both innerStruct and field1 are declared non-nullable
  // (no union with "null"), so a null at either level in the Catalyst data must
  // surface AVRO_CANNOT_WRITE_NULL_FIELD. Loop-invariant, so built once up front.
  val avroTypeStruct = s"""{
    |  "type": "record",
    |  "name": "outerStruct",
    |  "fields": [
    |    {
    |      "name": "innerStruct",
    |      "type": {
    |        "type": "record",
    |        "name": "innerStruct",
    |        "fields": [
    |          {"name": "field1", "type": "int"}
    |        ]
    |      }
    |    }
    |  ]
    |}
  """.stripMargin // nullability mismatch for innerStruct

  // First tuple: nullable leaf field; second tuple: nullable inner struct.
  for ((innerNull, outerNull) <- Seq((true, false), (false, true))) {
    val innerType = StructType(Seq(StructField("field1", IntegerType, innerNull)))
    val outerType = StructType(Seq(StructField("innerStruct", innerType, outerNull)))
    val rootType = StructType(Seq(StructField("outerStruct", outerType, false)))

    // Put the null at whichever level the Catalyst schema permits it.
    val nullRow = if (innerNull) Row(Row(null)) else Row(null)
    val rows = Seq(Row(Row(Row(1))), Row(nullRow), Row(Row(Row(3))))
    val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), rootType)

    // The reported field name and Avro type depend on where the null sits.
    val (expectedErrorName, expectedErrorSchema) =
      if (outerNull) {
        ("`innerStruct`",
          "{\"type\":\"record\",\"name\":\"innerStruct\"" +
            ",\"fields\":[{\"name\":\"field1\",\"type\":\"int\"}]}")
      } else {
        ("`field1`", "\"int\"")
      }

    checkError(
      exception = intercept[SparkRuntimeException] {
        df.select(avro.functions.to_avro($"outerStruct", avroTypeStruct)).collect()
      },
      condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
      parameters = Map(
        "name" -> expectedErrorName,
        "dataType" -> expectedErrorSchema))
  }
}

test("support user provided avro schema for writing nullable fixed type") {
withTempPath { tempDir =>
val avroSchema =
Expand Down Expand Up @@ -1517,9 +1578,12 @@ abstract class AvroSuite
.save(s"$tempDir/${UUID.randomUUID()}")
}
assert(ex.getCondition == "TASK_WRITE_FAILED")
assert(ex.getCause.isInstanceOf[java.lang.NullPointerException])
assert(ex.getCause.getMessage.contains(
"null value for (non-nullable) string at test_schema.Name"))
checkError(
ex.getCause.asInstanceOf[SparkThrowable],
condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
parameters = Map(
"name" -> "`Name`",
"dataType" -> "\"string\""))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ import org.apache.avro.Schema.Type._
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
import org.apache.avro.util.Utf8

import org.apache.spark.SparkRuntimeException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -282,11 +284,20 @@ private[sql] class AvroSerializer(
}.toArray.unzip

val numFields = catalystStruct.length
val avroFields = avroStruct.getFields()
val isSchemaNullable = avroFields.asScala.map(_.schema().isNullable)
row: InternalRow =>
val result = new Record(avroStruct)
var i = 0
while (i < numFields) {
if (row.isNullAt(i)) {
if (!isSchemaNullable(i)) {
throw new SparkRuntimeException(
errorClass = "AVRO_CANNOT_WRITE_NULL_FIELD",
messageParameters = Map(
"name" -> toSQLId(avroFields.get(i).name),
"dataType" -> avroFields.get(i).schema().toString))
}
result.put(avroIndices(i), null)
} else {
result.put(avroIndices(i), fieldConverters(i).apply(row, i))
Expand Down