
Commit 8d54bf7

MaxGekk authored and HyukjinKwon committed
[SPARK-26099][SQL] Verification of the corrupt column in from_csv/from_json
## What changes were proposed in this pull request?

The corrupt column specified via the JSON/CSV option *columnNameOfCorruptRecord* must have the `string` type and be `nullable`. This has already been checked in `DataFrameReader.csv`/`json` and in `Json`/`CsvFileFormat`, but not in `from_json`/`from_csv`. The PR adds the same checks inside those functions as well.

## How was this patch tested?

Added tests to `Json`/`CsvExpressionsSuite` that check the type of the corrupt column. They don't check the `nullable` property because `schema` is forcibly cast to nullable.

Closes apache#23070 from MaxGekk/verify-corrupt-column-csv-json.

Authored-by: Maxim Gekk <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>
1 parent ab2eafb commit 8d54bf7
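
For context, a minimal sketch of the user-facing effect, assuming a running SparkSession named `spark` (the session and the literal input are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.functions.{from_csv, lit}
import org.apache.spark.sql.types.StructType

// The corrupt column `_unparsed` is declared boolean instead of a nullable
// string, so the query now fails with AnalysisException("The field for
// corrupt records must be string type and nullable") instead of
// misbehaving during parsing.
val schema = StructType.fromDDL("i int, _unparsed boolean")
val options = Map("columnNameOfCorruptRecord" -> "_unparsed")
spark.range(1).select(from_csv(lit("a"), schema, options)).collect()
```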

8 files changed, +51 -33 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala

Lines changed: 16 additions & 0 deletions
@@ -67,4 +67,20 @@ object ExprUtils {
     case _ =>
       throw new AnalysisException("Must use a map() function for options")
   }
+
+  /**
+   * A convenient function for schema validation in datasources supporting
+   * `columnNameOfCorruptRecord` as an option.
+   */
+  def verifyColumnNameOfCorruptRecord(
+      schema: StructType,
+      columnNameOfCorruptRecord: String): Unit = {
+    schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = schema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+  }
 }
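
For illustration, a minimal sketch of how the extracted helper behaves, calling the catalyst-internal `ExprUtils` directly (an assumption for demonstration; applications would normally hit this check through `from_json`/`from_csv` or the readers):

```scala
import scala.util.Try

import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.types.StructType

// Nullable string corrupt column: the check passes silently.
// (Fields from StructType.fromDDL are nullable by default.)
ExprUtils.verifyColumnNameOfCorruptRecord(
  StructType.fromDDL("i INT, _corrupt STRING"), "_corrupt")

// Non-string corrupt column: throws AnalysisException("The field for
// corrupt records must be string type and nullable"); Try captures it.
Try(ExprUtils.verifyColumnNameOfCorruptRecord(
  StructType.fromDDL("i INT, _corrupt BOOLEAN"), "_corrupt"))

// Schema without the corrupt column: getFieldIndex returns None and the
// check is skipped, so unrelated schemas are unaffected.
ExprUtils.verifyColumnNameOfCorruptRecord(
  StructType.fromDDL("i INT"), "_corrupt")
```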

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala

Lines changed: 4 additions & 0 deletions
@@ -106,6 +106,10 @@ case class CsvToStructs(
       throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. " +
         s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.")
     }
+    ExprUtils.verifyColumnNameOfCorruptRecord(
+      nullableSchema,
+      parsedOptions.columnNameOfCorruptRecord)
+
     val actualSchema =
       StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
     val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 1 addition & 0 deletions
@@ -579,6 +579,7 @@ case class JsonToStructs(
     }
     val (parserSchema, actualSchema) = nullableSchema match {
       case s: StructType =>
+        ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord)
         (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)))
       case other =>
         (StructType(StructField("value", other) :: Nil), other)
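
Worth noting: the verification runs only in the `StructType` branch. A short sketch of the distinction, assuming a running SparkSession named `spark` (session and literals are illustrative):

```scala
import org.apache.spark.sql.functions.{from_json, lit}
import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructType}

val opts = Map("columnNameOfCorruptRecord" -> "_unparsed")

// A non-struct schema takes the `case other` branch: there is no corrupt
// column to verify, so this call is unaffected by the change.
spark.range(1).select(
  from_json(lit("""{"a": 1}"""), MapType(StringType, IntegerType), opts)).collect()

// A struct schema with a boolean corrupt column now fails with
// "The field for corrupt records must be string type and nullable".
spark.range(1).select(
  from_json(lit("""{"i":"a"}"""),
    StructType.fromDDL("i int, _unparsed boolean"), opts)).collect()
```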

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala

Lines changed: 11 additions & 0 deletions
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale}
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.util._
@@ -226,4 +227,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
       InternalRow(17836)) // number of days from 1970-01-01
     }
   }
+
+  test("verify corrupt column") {
+    checkExceptionInExpression[AnalysisException](
+      CsvToStructs(
+        schema = StructType.fromDDL("i int, _unparsed boolean"),
+        options = Map("columnNameOfCorruptRecord" -> "_unparsed"),
+        child = Literal.create("a"),
+        timeZoneId = gmtId),
+      expectedErrMsg = "The field for corrupt records must be string type and nullable")
+  }
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 11 additions & 0 deletions
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale}
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.plans.PlanTestBase
@@ -754,4 +755,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
       InternalRow(17836)) // number of days from 1970-01-01
     }
   }
+
+  test("verify corrupt column") {
+    checkExceptionInExpression[AnalysisException](
+      JsonToStructs(
+        schema = StructType.fromDDL("i int, _unparsed boolean"),
+        options = Map("columnNameOfCorruptRecord" -> "_unparsed"),
+        child = Literal.create("""{"i":"a"}"""),
+        timeZoneId = gmtId),
+      expectedErrMsg = "The field for corrupt records must be string type and nullable")
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 3 additions & 18 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
 import org.apache.spark.sql.catalyst.util.FailureSafeParser
 import org.apache.spark.sql.execution.command.DDLUtils
@@ -442,7 +443,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
     }
 
-    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
     val actualSchema =
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
@@ -504,7 +505,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       parsedOptions)
     }
 
-    verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
     val actualSchema =
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
@@ -765,22 +766,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
   }
 
-  /**
-   * A convenient function for schema validation in datasources supporting
-   * `columnNameOfCorruptRecord` as an option.
-   */
-  private def verifyColumnNameOfCorruptRecord(
-      schema: StructType,
-      columnNameOfCorruptRecord: String): Unit = {
-    schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
-      val f = schema(corruptFieldIndex)
-      if (f.dataType != StringType || !f.nullable) {
-        throw new AnalysisException(
-          "The field for corrupt records must be string type and nullable")
-      }
-    }
-  }
-
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////
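
The two Dataset-based reader entry points keep their existing eager check, now delegated to the shared helper. A minimal sketch, assuming a SparkSession `spark` with its implicits in scope:

```scala
import spark.implicits._

import org.apache.spark.sql.types.StructType

val badSchema = StructType.fromDDL("i int, _corrupt boolean")

// Fails on the driver, before any job runs, with the same message as the
// new expression-level checks.
spark.read
  .schema(badSchema)
  .option("columnNameOfCorruptRecord", "_corrupt")
  .json(Seq("""{"i": 1}""").toDS())
```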

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala

Lines changed: 2 additions & 7 deletions
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityGenerator, UnivocityParser}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
@@ -110,13 +111,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
     // Check a field requirement for corrupt records here to throw an exception in a driver side
-    dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
-      val f = dataSchema(corruptFieldIndex)
-      if (f.dataType != StringType || !f.nullable) {
-        throw new AnalysisException(
-          "The field for corrupt records must be string type and nullable")
-      }
-    }
+    ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
 
     if (requiredSchema.length == 1 &&
       requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala

Lines changed: 3 additions & 8 deletions
@@ -26,7 +26,8 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.json._
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
@@ -107,13 +108,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
     val actualSchema =
       StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
     // Check a field requirement for corrupt records here to throw an exception in a driver side
-    dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
-      val f = dataSchema(corruptFieldIndex)
-      if (f.dataType != StringType || !f.nullable) {
-        throw new AnalysisException(
-          "The field for corrupt records must be string type and nullable")
-      }
-    }
+    ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
 
     if (requiredSchema.length == 1 &&
       requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
