Commit c9667af

HyukjinKwon and MaxGekk committed
# [SPARK-25672][SQL] schema_of_csv() - schema inference from an example

## What changes were proposed in this pull request?

In the PR, I propose to add a new function - *schema_of_csv()* - which infers the schema of a CSV string literal. The result of the function is a string containing a schema in DDL format. For example:

```sql
select schema_of_csv('1|abc', map('delimiter', '|'))
```

```
struct<_c0:int,_c1:string>
```

## How was this patch tested?

Added new tests to `CsvFunctionsSuite` and `CsvExpressionsSuite`, and SQL tests to `csv-functions.sql`.

Closes apache#22666 from MaxGekk/schema_of_csv-function.

Lead-authored-by: hyukjinkwon <[email protected]>
Co-authored-by: Maxim Gekk <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>
1 parent: c5ef477 · commit: c9667af
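To try the new function end to end, here is a minimal sketch (not part of this commit's diff), assuming a `SparkSession` named `spark` on a build that includes this change:

```scala
// Call the new SQL function and read back the inferred DDL schema as a string.
val ddl = spark
  .sql("SELECT schema_of_csv('1|abc', map('delimiter', '|'))")
  .head()
  .getString(0)
// ddl == "struct<_c0:int,_c1:string>", matching the example in the commit message
```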

14 files changed: +262 −45 lines

### python/pyspark/sql/functions.py (34 additions, 7 deletions)

```diff
@@ -2364,6 +2364,33 @@ def schema_of_json(json, options={}):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(3.0)
+def schema_of_csv(csv, options={}):
+    """
+    Parses a CSV string and infers its schema in DDL format.
+
+    :param col: a CSV string or a string literal containing a CSV string.
+    :param options: options to control parsing. accepts the same options as the CSV datasource
+
+    >>> df = spark.range(1)
+    >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect()
+    [Row(csv=u'struct<_c0:int,_c1:string>')]
+    >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect()
+    [Row(csv=u'struct<_c0:int,_c1:string>')]
+    """
+    if isinstance(csv, basestring):
+        col = _create_column_from_literal(csv)
+    elif isinstance(csv, Column):
+        col = _to_java_column(csv)
+    else:
+        raise TypeError("schema argument should be a column or string")
+
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.schema_of_csv(col, options)
+    return Column(jc)
+
+
 @since(1.5)
 def size(col):
     """
@@ -2664,13 +2691,13 @@ def from_csv(col, schema, options={}):
     :param schema: a string with schema in DDL format to use when parsing the CSV column.
     :param options: options to control parsing. accepts the same options as the CSV datasource
 
-    >>> data = [(1, '1')]
-    >>> df = spark.createDataFrame(data, ("key", "value"))
-    >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect()
-    [Row(csv=Row(a=1))]
-    >>> df = spark.createDataFrame(data, ("key", "value"))
-    >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect()
-    [Row(csv=Row(a=1))]
+    >>> data = [("1,2,3",)]
+    >>> df = spark.createDataFrame(data, ("value",))
+    >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect()
+    [Row(csv=Row(a=1, b=2, c=3))]
+    >>> value = data[0][0]
+    >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect()
+    [Row(csv=Row(_c0=1, _c1=2, _c2=3))]
     """
 
     sc = SparkContext._active_spark_context
```

### sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala (2 additions, 1 deletion)

```diff
@@ -526,7 +526,8 @@ object FunctionRegistry {
     castAlias("string", StringType),
 
     // csv
-    expression[CsvToStructs]("from_csv")
+    expression[CsvToStructs]("from_csv"),
+    expression[SchemaOfCsv]("schema_of_csv")
   )
 
   val builtin: SimpleFunctionRegistry = {
```
### sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala (18 additions, 12 deletions)

```diff
@@ -15,19 +15,18 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.datasources.csv
+package org.apache.spark.sql.catalyst.csv
 
 import java.math.BigDecimal
 
-import scala.util.control.Exception._
+import scala.util.control.Exception.allCatch
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion
-import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
-private[csv] object CSVInferSchema {
+object CSVInferSchema {
 
   /**
    * Similar to the JSON schema inference
@@ -44,13 +43,7 @@ private[csv] object CSVInferSchema {
       val rootTypes: Array[DataType] =
         tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes)
 
-      header.zip(rootTypes).map { case (thisHeader, rootType) =>
-        val dType = rootType match {
-          case _: NullType => StringType
-          case other => other
-        }
-        StructField(thisHeader, dType, nullable = true)
-      }
+      toStructFields(rootTypes, header, options)
     } else {
       // By default fields are assumed to be StringType
       header.map(fieldName => StructField(fieldName, StringType, nullable = true))
@@ -59,7 +52,20 @@ private[csv] object CSVInferSchema {
     StructType(fields)
   }
 
-  private def inferRowType(options: CSVOptions)
+  def toStructFields(
+      fieldTypes: Array[DataType],
+      header: Array[String],
+      options: CSVOptions): Array[StructField] = {
+    header.zip(fieldTypes).map { case (thisHeader, rootType) =>
+      val dType = rootType match {
+        case _: NullType => StringType
+        case other => other
+      }
+      StructField(thisHeader, dType, nullable = true)
+    }
+  }
+
+  def inferRowType(options: CSVOptions)
       (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
     var i = 0
     while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
```
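The two methods made public here, `inferRowType` and the new `toStructFields`, are exactly what `SchemaOfCsv.eval` composes further down in this commit. A sketch of that composition, using only APIs visible in this diff (and assuming the same imports as `CSVInferSchema` itself):

```scala
// Infer a type per column from one parsed record, then let toStructFields
// replace NullType (no type evidence) with StringType and build the fields.
val options = new CSVOptions(Map("delimiter" -> "|"), true, "UTC")
val row = Array("1", "abc")                                    // one parsed CSV record
val start: Array[DataType] = Array.fill(row.length)(NullType)  // no evidence yet
val types = CSVInferSchema.inferRowType(options)(start, row)
val fields = CSVInferSchema.toStructFields(types, Array("_c0", "_c1"), options)
// StructType(fields).catalogString == "struct<_c0:int,_c1:string>"
```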

### sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala (29 additions, 4 deletions)

```diff
@@ -19,14 +19,39 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
-import org.apache.spark.sql.types.{MapType, StringType, StructType}
+import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
 
 object ExprUtils {
 
-  def evalSchemaExpr(exp: Expression): StructType = exp match {
-    case Literal(s, StringType) => StructType.fromDDL(s.toString)
+  def evalSchemaExpr(exp: Expression): StructType = {
+    // Use `DataType.fromDDL` since the type string can be struct<...>.
+    val dataType = exp match {
+      case Literal(s, StringType) =>
+        DataType.fromDDL(s.toString)
+      case e @ SchemaOfCsv(_: Literal, _) =>
+        val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
+        DataType.fromDDL(ddlSchema.toString)
+      case e => throw new AnalysisException(
+        "Schema should be specified in DDL format as a string literal or output of " +
+          s"the schema_of_csv function instead of ${e.sql}")
+    }
+
+    if (!dataType.isInstanceOf[StructType]) {
+      throw new AnalysisException(
+        s"Schema should be struct type but got ${dataType.sql}.")
+    }
+    dataType.asInstanceOf[StructType]
+  }
+
+  def evalTypeExpr(exp: Expression): DataType = exp match {
+    case Literal(s, StringType) => DataType.fromDDL(s.toString)
+    case e @ SchemaOfJson(_: Literal, _) =>
+      val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
+      DataType.fromDDL(ddlSchema.toString)
     case e => throw new AnalysisException(
-      s"Schema should be specified in DDL format as a string literal instead of ${e.sql}")
+      "Schema should be specified in DDL format as a string literal or output of " +
+        s"the schema_of_json function instead of ${e.sql}")
   }
 
   def convertToMapData(exp: Expression): Map[String, String] = exp match {
```
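In effect, `from_csv` now accepts either a DDL string literal or a `schema_of_csv` call over a literal as its schema argument; anything else fails analysis. A sketch of both accepted forms and the failure mode (assuming a `SparkSession` named `spark`):

```scala
// Both forms below pass the new evalSchemaExpr check.
spark.sql("SELECT from_csv('1,abc', 'a INT, b STRING')")       // literal DDL schema
spark.sql("SELECT from_csv('1,abc', schema_of_csv('1,abc'))")  // schema inferred from an example
// Any other schema expression is rejected with the AnalysisException added here:
// "Schema should be specified in DDL format as a string literal or output of
//  the schema_of_csv function instead of ..."
```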

### sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala (54 additions, 0 deletions)

```diff
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import com.univocity.parsers.csv.CsvParser
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.csv._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util._
@@ -120,3 +123,54 @@ case class CsvToStructs(
 
   override def prettyName: String = "from_csv"
 }
+
+/**
+ * A function infers schema of CSV string.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(csv[, options]) - Returns schema in the DDL format of CSV string.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('1,abc');
+       struct<_c0:int,_c1:string>
+  """,
+  since = "3.0.0")
+case class SchemaOfCsv(
+    child: Expression,
+    options: Map[String, String])
+  extends UnaryExpression with CodegenFallback {
+
+  def this(child: Expression) = this(child, Map.empty[String, String])
+
+  def this(child: Expression, options: Expression) = this(
+    child = child,
+    options = ExprUtils.convertToMapData(options))
+
+  override def dataType: DataType = StringType
+
+  override def nullable: Boolean = false
+
+  @transient
+  private lazy val csv = child.eval().asInstanceOf[UTF8String]
+
+  override def checkInputDataTypes(): TypeCheckResult = child match {
+    case Literal(s, StringType) if s != null => super.checkInputDataTypes()
+    case _ => TypeCheckResult.TypeCheckFailure(
+      s"The input csv should be a string literal and not null; however, got ${child.sql}.")
+  }
+
+  override def eval(v: InternalRow): Any = {
+    val parsedOptions = new CSVOptions(options, true, "UTC")
+    val parser = new CsvParser(parsedOptions.asParserSettings)
+    val row = parser.parseLine(csv.toString)
+    assert(row != null, "Parsed CSV record should not be null.")
+
+    val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
+    val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+    val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row)
+    val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions))
+    UTF8String.fromString(st.catalogString)
+  }
+
+  override def prettyName: String = "schema_of_csv"
+}
```
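Since the expression only accepts a non-null string literal, it can be evaluated eagerly without an input row. A sketch of direct evaluation, mirroring the new tests in `CsvExpressionsSuite` below:

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, SchemaOfCsv}

// eval() needs no input row because the child is a literal; the result is a
// UTF8String holding the inferred schema in DDL format.
val inferred = new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")).eval()
// inferred.toString == "struct<_c0:int,_c1:string>"
```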

### sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala (2 additions, 14 deletions)

```diff
@@ -529,7 +529,7 @@ case class JsonToStructs(
   // Used in `FunctionRegistry`
   def this(child: Expression, schema: Expression, options: Map[String, String]) =
     this(
-      schema = JsonExprUtils.evalSchemaExpr(schema),
+      schema = ExprUtils.evalTypeExpr(schema),
       options = options,
       child = child,
       timeZoneId = None)
@@ -538,7 +538,7 @@ case class JsonToStructs(
 
   def this(child: Expression, schema: Expression, options: Expression) =
     this(
-      schema = JsonExprUtils.evalSchemaExpr(schema),
+      schema = ExprUtils.evalTypeExpr(schema),
       options = ExprUtils.convertToMapData(options),
       child = child,
       timeZoneId = None)
@@ -784,15 +784,3 @@ case class SchemaOfJson(
 
   override def prettyName: String = "schema_of_json"
 }
-
-object JsonExprUtils {
-  def evalSchemaExpr(exp: Expression): DataType = exp match {
-    case Literal(s, StringType) => DataType.fromDDL(s.toString)
-    case e @ SchemaOfJson(_: Literal, _) =>
-      val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
-      DataType.fromDDL(ddlSchema.toString)
-    case e => throw new AnalysisException(
-      "Schema should be specified in DDL format as a string literal" +
-        s" or output of the schema_of_json function instead of ${e.sql}")
-  }
-}
```
### sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala (1 addition, 2 deletions)

```diff
@@ -15,10 +15,9 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.datasources.csv
+package org.apache.spark.sql.catalyst.csv
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.types._
 
 class CSVInferSchemaSuite extends SparkFunSuite {
```
### sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala (1 addition, 2 deletions)

```diff
@@ -15,12 +15,11 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.datasources.csv
+package org.apache.spark.sql.catalyst.csv
 
 import java.math.BigDecimal
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
```

### sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala (10 additions, 0 deletions)

```diff
@@ -155,4 +155,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
     }.getCause
     assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode"))
   }
+
+  test("infer schema of CSV strings") {
+    checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>")
+  }
+
+  test("infer schema of CSV strings by using options") {
+    checkEvaluation(
+      new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")),
+      "struct<_c0:int,_c1:string>")
+  }
 }
```

### sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala (1 addition, 1 deletion)

```diff
@@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
+import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
```
