
Commit 17edfec

wangyum authored and gatorsmile committed
[SPARK-20427][SQL] Read JDBC table use custom schema
## What changes were proposed in this pull request?

The auto-generated Oracle schema is sometimes not what we expect:

- `number(1)` is auto-mapped to BooleanType, which is sometimes not what we expect, per [SPARK-20921](https://issues.apache.org/jira/browse/SPARK-20921).
- `number` is auto-mapped to Decimal(38, 10), which cannot read large values, per [SPARK-20427](https://issues.apache.org/jira/browse/SPARK-20427).

This PR fixes the issue by letting users supply a custom schema, as follows:

```scala
val props = new Properties()
props.put("customSchema", "ID decimal(38, 0), N1 int, N2 boolean")
val dfRead = spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", props)
dfRead.show()
```

or

```sql
CREATE TEMPORARY VIEW tableWithCustomSchema
USING org.apache.spark.sql.jdbc
OPTIONS (url '$jdbcUrl', dbTable 'tableWithCustomSchema', customSchema 'ID decimal(38, 0), N1 int, N2 boolean')
```

## How was this patch tested?

Unit tests.

Author: Yuming Wang <[email protected]>

Closes apache#18266 from wangyum/SPARK-20427.
1 parent 8c7e19a commit 17edfec
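
The commit message shows the java.util.Properties and SQL forms; the same option also goes through the DataFrameReader options map, as the Python example further down does. A minimal Scala sketch of that form — `jdbcUrl` and the table name here are placeholders, not part of this commit:

```scala
// Sketch only: assumes a SparkSession named `spark` and a JDBC url in `jdbcUrl`.
// "customSchema" overrides the types Spark would otherwise infer from the JDBC
// metadata for the named columns; it applies to reads only.
val dfRead = spark.read
  .format("jdbc")
  .option("url", jdbcUrl)
  .option("dbtable", "tableWithCustomSchema")
  .option("customSchema", "ID decimal(38, 0), N1 int, N2 boolean")
  .load()
dfRead.printSchema()
```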


10 files changed: +222 -10 lines


docs/sql-programming-guide.md

Lines changed: 8 additions & 1 deletion
@@ -1328,7 +1328,14 @@ the following case-insensitive options:
     <td>
       The database column data types to use instead of the defaults, when creating the table. Data type information should be specified in the same format as CREATE TABLE columns syntax (e.g: <code>"name CHAR(64), comments VARCHAR(1024)")</code>. The specified types should be valid spark sql data types. This option applies only to writing.
     </td>
-  </tr>
+  </tr>
+
+  <tr>
+    <td><code>customSchema</code></td>
+    <td>
+      The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING"). The column names should be identical to the corresponding column names of JDBC table. Users can specify the corresponding data types of Spark SQL instead of using the defaults. This option applies only to reading.
+    </td>
+  </tr>
 </table>

 <div class="codetabs">

examples/src/main/python/sql/datasource.py

Lines changed: 10 additions & 0 deletions
@@ -177,6 +177,16 @@ def jdbc_dataset_example(spark):
         .jdbc("jdbc:postgresql:dbserver", "schema.tablename",
               properties={"user": "username", "password": "password"})

+    # Specifying dataframe column data types on read
+    jdbcDF3 = spark.read \
+        .format("jdbc") \
+        .option("url", "jdbc:postgresql:dbserver") \
+        .option("dbtable", "schema.tablename") \
+        .option("user", "username") \
+        .option("password", "password") \
+        .option("customSchema", "id DECIMAL(38, 0), name STRING") \
+        .load()
+
     # Saving data to a JDBC source
     jdbcDF.write \
         .format("jdbc") \

examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala

Lines changed: 4 additions & 0 deletions
@@ -185,6 +185,10 @@ object SQLDataSourceExample {
     connectionProperties.put("password", "password")
     val jdbcDF2 = spark.read
       .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
+    // Specifying the custom data types of the read schema
+    connectionProperties.put("customSchema", "id DECIMAL(38, 0), name STRING")
+    val jdbcDF3 = spark.read
+      .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)

     // Saving data to a JDBC source
     jdbcDF.write

external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala

Lines changed: 41 additions & 6 deletions
@@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp}
 import java.util.Properties
 import java.math.BigDecimal

-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -72,10 +72,17 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
       """.stripMargin.replaceAll("\n", " ")).executeUpdate()
     conn.commit()

-    conn.prepareStatement("CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)")
-      .executeUpdate()
-    conn.prepareStatement("INSERT INTO ts_with_timezone VALUES (1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))")
-      .executeUpdate()
+    conn.prepareStatement(
+      "CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)").executeUpdate()
+    conn.prepareStatement(
+      "INSERT INTO ts_with_timezone VALUES " +
+        "(1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))").executeUpdate()
+    conn.commit()
+
+    conn.prepareStatement(
+      "CREATE TABLE tableWithCustomSchema (id NUMBER, n1 NUMBER(1), n2 NUMBER(1))").executeUpdate()
+    conn.prepareStatement(
+      "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)").executeUpdate()
     conn.commit()

     sql(
@@ -104,7 +111,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
   }


-  test("SPARK-16625 : Importing Oracle numeric types") {
+  test("SPARK-16625 : Importing Oracle numeric types") {
     val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties);
     val rows = df.collect()
     assert(rows.size == 1)
@@ -272,4 +279,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
     assert(row.getDate(0).equals(dateVal))
     assert(row.getTimestamp(1).equals(timestampVal))
   }
+
+  test("SPARK-20427/SPARK-20921: read table use custom schema by jdbc api") {
+    // default will throw IllegalArgumentException
+    val e = intercept[org.apache.spark.SparkException] {
+      spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", new Properties()).collect()
+    }
+    assert(e.getMessage.contains(
+      "requirement failed: Decimal precision 39 exceeds max precision 38"))
+
+    // custom schema can read data
+    val props = new Properties()
+    props.put("customSchema",
+      s"ID DECIMAL(${DecimalType.MAX_PRECISION}, 0), N1 INT, N2 BOOLEAN")
+    val dfRead = spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", props)
+
+    val rows = dfRead.collect()
+    // verify the data type
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types(0).equals("class java.math.BigDecimal"))
+    assert(types(1).equals("class java.lang.Integer"))
+    assert(types(2).equals("class java.lang.Boolean"))
+
+    // verify the value
+    val values = rows(0)
+    assert(values.getDecimal(0).equals(new java.math.BigDecimal("12312321321321312312312312123")))
+    assert(values.getInt(1).equals(1))
+    assert(values.getBoolean(2).equals(false))
+  }
 }
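
A quick check of the arithmetic behind the first assertion in the new test: the inserted id has 29 digits, and under the default NUMBER-to-DECIMAL(38, 10) mapping described in the commit message it would need 29 + 10 = 39 digits of precision, one more than Spark allows. A small sketch (DecimalType.MAX_PRECISION is the same constant the test uses):

```scala
import org.apache.spark.sql.types.DecimalType

// 29 integer digits plus 10 fractional digits under the default DECIMAL(38, 10) mapping.
val integerDigits = "12312321321321312312312312123".length  // 29
val neededPrecision = integerDigits + 10                     // 39
assert(neededPrecision > DecimalType.MAX_PRECISION)          // 39 > 38, hence the error message
```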

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,7 @@ import java.sql.{Connection, DriverManager}
 import java.util.{Locale, Properties}

 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.types.StructType

 /**
  * Options for the JDBC data source.
@@ -123,6 +124,8 @@ class JDBCOptions(
   // TODO: to reuse the existing partition parameters for those partition specific options
   val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
   val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES)
+  val customSchema = parameters.get(JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES)
+
   val batchSize = {
     val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
     require(size >= 1,
@@ -161,6 +164,7 @@ object JDBCOptions {
   val JDBC_TRUNCATE = newOption("truncate")
   val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
   val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
+  val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")
   val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
   val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
   val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
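
Like the other JDBC options, the new `customSchema` key is looked up through `CaseInsensitiveMap`, so the casing users write does not matter; the mixed-case `OPTIONS` in the JDBCSuite test further down relies on this. A minimal sketch of that lookup, assuming only that `CaseInsensitiveMap` behaves as a case-insensitive `Map`:

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// A key written in any casing resolves to the same entry.
val parameters = CaseInsensitiveMap(Map("CUSTOMSCHEMA" -> "id DECIMAL(38, 0), name STRING"))
val customSchema: Option[String] = parameters.get("customSchema")
assert(customSchema.contains("id DECIMAL(38, 0), name STRING"))
```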

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ object JDBCRDD extends Logging {
    * @return A Catalyst schema corresponding to columns in the given order.
    */
   private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
-    val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*)
+    val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
     new StructType(columns.map(name => fieldMap(name)))
   }

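Because JdbcUtils below stops writing a "name" entry into the field metadata, pruning now keys off the `StructField` names themselves, which is also what a user-supplied custom schema carries. A standalone sketch of the same pruning idea (not the Spark-internal method):

```scala
import org.apache.spark.sql.types._

// Build a name -> field map from the field names (no "name" metadata needed),
// then project the requested columns in their given order.
val schema = StructType(Seq(
  StructField("ID", DecimalType(38, 0)),
  StructField("N1", IntegerType),
  StructField("N2", BooleanType)))
val fieldMap = schema.fields.map(f => f.name -> f).toMap
val pruned = StructType(Array("N2", "ID").map(fieldMap))  // keeps only N2 and ID, in that order
```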

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 8 additions & 1 deletion
@@ -111,7 +111,14 @@ private[sql] case class JDBCRelation(

   override val needConversion: Boolean = false

-  override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions)
+  override val schema: StructType = {
+    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
+    jdbcOptions.customSchema match {
+      case Some(customSchema) => JdbcUtils.getCustomSchema(
+        tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
+      case None => tableSchema
+    }
+  }

   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
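
So the schema resolved from the database is only a fallback: when `customSchema` is set, the user schema (validated by `JdbcUtils.getCustomSchema` below) takes its place. A simplified sketch of that selection, leaving out the validation step; `effectiveSchema` is just an illustrative name:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.StructType

// Prefer the user-specified schema when the option is present, otherwise keep the resolved one.
def effectiveSchema(resolvedSchema: StructType, customSchema: Option[String]): StructType =
  customSchema
    .map(CatalystSqlParser.parseTableSchema)  // e.g. "ID decimal(38, 0), N1 int, N2 boolean"
    .getOrElse(resolvedSchema)
```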

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Lines changed: 29 additions & 1 deletion
@@ -29,6 +29,7 @@ import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -301,7 +302,6 @@ object JdbcUtils extends Logging {
       rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
     }
     val metadata = new MetadataBuilder()
-      .putString("name", columnName)
       .putLong("scale", fieldScale)
     val columnType =
       dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
@@ -767,6 +767,34 @@ object JdbcUtils extends Logging {
     if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap)
   }

+  /**
+   * Parses the user specified customSchema option value to DataFrame schema,
+   * and returns it if it's all columns are equals to default schema's.
+   */
+  def getCustomSchema(
+      tableSchema: StructType,
+      customSchema: String,
+      nameEquality: Resolver): StructType = {
+    val userSchema = CatalystSqlParser.parseTableSchema(customSchema)
+
+    SchemaUtils.checkColumnNameDuplication(
+      userSchema.map(_.name), "in the customSchema option value", nameEquality)
+
+    val colNames = tableSchema.fieldNames.mkString(",")
+    val errorMsg = s"Please provide all the columns, all columns are: $colNames"
+    if (userSchema.size != tableSchema.size) {
+      throw new AnalysisException(errorMsg)
+    }
+
+    // This is resolved by names, only check the column names.
+    userSchema.fieldNames.foreach { col =>
+      tableSchema.find(f => nameEquality(f.name, col)).getOrElse {
+        throw new AnalysisException(errorMsg)
+      }
+    }
+    userSchema
+  }
+
   /**
    * Saves the RDD to the database in a single transaction.
    */
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.types._
+
+class JdbcUtilsSuite extends SparkFunSuite {
+
+  val tableSchema = StructType(Seq(
+    StructField("C1", StringType, false), StructField("C2", IntegerType, false)))
+  val caseSensitive = org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+  val caseInsensitive = org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+
+  test("Parse user specified column types") {
+    assert(
+      JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseInsensitive) ===
+      StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true))))
+    assert(JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseSensitive) ===
+      StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true))))
+    assert(
+      JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseInsensitive) ===
+      StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true))))
+    assert(JdbcUtils.getCustomSchema(
+      tableSchema, "c1 DECIMAL(38, 0), C2 STRING", caseInsensitive) ===
+      StructType(Seq(StructField("c1", DecimalType(38, 0), true),
+        StructField("C2", StringType, true))))
+
+    // Throw AnalysisException
+    val duplicate = intercept[AnalysisException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, c1 STRING", caseInsensitive) ===
+        StructType(Seq(StructField("c1", DateType, true), StructField("c1", StringType, true)))
+    }
+    assert(duplicate.getMessage.contains(
+      "Found duplicate column(s) in the customSchema option value"))
+
+    val allColumns = intercept[AnalysisException]{
+      JdbcUtils.getCustomSchema(tableSchema, "C1 STRING", caseSensitive) ===
+        StructType(Seq(StructField("C1", DateType, true)))
+    }
+    assert(allColumns.getMessage.contains("Please provide all the columns,"))
+
+    val caseSensitiveColumnNotFound = intercept[AnalysisException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseSensitive) ===
+        StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true)))
+    }
+    assert(caseSensitiveColumnNotFound.getMessage.contains(
+      "Please provide all the columns, all columns are: C1,C2;"))
+
+    val caseInsensitiveColumnNotFound = intercept[AnalysisException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c3 DATE, C2 STRING", caseInsensitive) ===
+        StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true)))
+    }
+    assert(caseInsensitiveColumnNotFound.getMessage.contains(
+      "Please provide all the columns, all columns are: C1,C2;"))
+
+    // Throw ParseException
+    val dataTypeNotSupported = intercept[ParseException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c3 DATEE, C2 STRING", caseInsensitive) ===
+        StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true)))
+    }
+    assert(dataTypeNotSupported.getMessage.contains("DataType datee is not supported"))
+
+    val mismatchedInput = intercept[ParseException]{
+      JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. C2 STRING", caseInsensitive) ===
+        StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true)))
+    }
+    assert(mismatchedInput.getMessage.contains("mismatched input '.' expecting"))
+  }
+}

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala

Lines changed: 30 additions & 0 deletions
@@ -968,6 +968,36 @@ class JDBCSuite extends SparkFunSuite
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }

+  test("jdbc API support custom schema") {
+    val parts = Array[String]("THEID < 2", "THEID >= 2")
+    val props = new Properties()
+    props.put("customSchema", "NAME STRING, THEID BIGINT")
+    val schema = StructType(Seq(
+      StructField("NAME", StringType, true), StructField("THEID", LongType, true)))
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, props)
+    assert(df.schema.size === 2)
+    assert(df.schema === schema)
+    assert(df.count() === 3)
+  }
+
+  test("jdbc API custom schema DDL-like strings.") {
+    withTempView("people_view") {
+      sql(
+        s"""
+          |CREATE TEMPORARY VIEW people_view
+          |USING org.apache.spark.sql.jdbc
+          |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass',
+          |customSchema 'NAME STRING, THEID INT')
+        """.stripMargin.replaceAll("\n", " "))
+      val schema = StructType(
+        Seq(StructField("NAME", StringType, true), StructField("THEID", IntegerType, true)))
+      val df = sql("select * from people_view")
+      assert(df.schema.size === 2)
+      assert(df.schema === schema)
+      assert(df.count() === 3)
+    }
+  }
+
   test("SPARK-15648: teradataDialect StringType data mapping") {
     val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db")
     assert(teradataDialect.getJDBCType(StringType).
