Commit f596ebe

maropu authored and gatorsmile committed
[SPARK-24327][SQL] Verify and normalize a partition column name based on the JDBC resolved schema
## What changes were proposed in this pull request?

This PR modifies the JDBC datasource code to verify and normalize a partition column based on the JDBC resolved schema before building `JDBCRelation`.

Closes apache#20370

## How was this patch tested?

Added tests in `JDBCSuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes apache#21379 from maropu/SPARK-24327.
1 parent a5849ad commit f596ebe
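For context, a partitioned JDBC read that exercises this path might look like the following minimal sketch, modeled on the new `JDBCSuite` test further below; the H2 URL is a placeholder. With this change, the user-supplied `partitionColumn` is checked against the resolved schema and normalized to the dialect-quoted field name (e.g. `ThEiD` resolves to `"THEID"` under a case-insensitive resolver), and an unknown column now fails early with an `AnalysisException`.

```scala
// A minimal sketch, not the committed test: the H2 URL and table are
// placeholders modeled on the TEST.PARTITION table created in JDBCSuite below.
val df = spark.read.format("jdbc")
  .option("url", "jdbc:h2:mem:testdb0;user=testUser;password=testPass") // hypothetical
  .option("dbtable", "TEST.PARTITION")
  .option("partitionColumn", "ThEiD") // normalized to "THEID" when case-insensitive
  .option("lowerBound", 1)
  .option("upperBound", 4)
  .option("numPartitions", 3)
  .load()
// Each of the 3 partitions then scans with a WHERE clause over the quoted
// column, e.g. "THEID" < 2 or "THEID" is null.
```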

4 files changed: +118 −17 lines changed

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -100,7 +100,7 @@ private[spark] object Utils extends Logging {
    */
   val DEFAULT_MAX_TO_STRING_FIELDS = 25
 
-  private def maxNumToStringFields = {
+  private[spark] def maxNumToStringFields = {
     if (SparkEnv.get != null) {
       SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
     } else {
```

(`maxNumToStringFields` is widened to `private[spark]` so the new error message in `JDBCRelation` can render the resolved schema via `schema.simpleString(Utils.maxNumToStringFields)`.)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 63 additions & 13 deletions
```diff
@@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.Partition
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * Instructions on how to partition the table among workers.
@@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging {
    * Null value predicate is added to the first partition where clause to include
    * the rows with null value for the partitions column.
    *
+   * @param schema resolved schema of a JDBC table
    * @param partitioning partition information to generate the where clause for each partition
+   * @param resolver function used to determine if two identifiers are equal
+   * @param jdbcOptions JDBC options that contains url
    * @return an array of partitions with where clause for each partition
    */
-  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
+  def columnPartition(
+      schema: StructType,
+      partitioning: JDBCPartitioningInfo,
+      resolver: Resolver,
+      jdbcOptions: JDBCOptions): Array[Partition] = {
     if (partitioning == null || partitioning.numPartitions <= 1 ||
       partitioning.lowerBound == partitioning.upperBound) {
       return Array[Partition](JDBCPartition(null, 0))
@@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging {
     // Overflow and silliness can happen if you subtract then divide.
     // Here we get a little roundoff, but that's (hopefully) OK.
     val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
-    val column = partitioning.column
+
+    val column = verifyAndGetNormalizedColumnName(
+      schema, partitioning.column, resolver, jdbcOptions)
+
     var i: Int = 0
     var currentValue: Long = lowerBound
     val ans = new ArrayBuffer[Partition]()
@@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging {
     }
     ans.toArray
   }
+
+  // Verify column name based on the JDBC resolved schema
+  private def verifyAndGetNormalizedColumnName(
+      schema: StructType,
+      columnName: String,
+      resolver: Resolver,
+      jdbcOptions: JDBCOptions): String = {
+    val dialect = JdbcDialects.get(jdbcOptions.url)
+    schema.map(_.name).find { fieldName =>
+      resolver(fieldName, columnName) ||
+        resolver(dialect.quoteIdentifier(fieldName), columnName)
+    }.map(dialect.quoteIdentifier).getOrElse {
+      throw new AnalysisException(s"User-defined partition column $columnName not " +
+        s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
+    }
+  }
+
+  /**
+   * Takes a (schema, table) specification and returns the table's Catalyst schema.
+   * If `customSchema` defined in the JDBC options, replaces the schema's dataType with the
+   * custom schema's type.
+   *
+   * @param resolver function used to determine if two identifiers are equal
+   * @param jdbcOptions JDBC options that contains url, table and other information.
+   * @return resolved Catalyst schema of a JDBC table
+   */
+  def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
+    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
+    jdbcOptions.customSchema match {
+      case Some(customSchema) => JdbcUtils.getCustomSchema(
+        tableSchema, customSchema, resolver)
+      case None => tableSchema
+    }
+  }
+
+  /**
+   * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema.
+   */
+  def apply(
+      parts: Array[Partition],
+      jdbcOptions: JDBCOptions)(
+      sparkSession: SparkSession): JDBCRelation = {
+    val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions)
+    JDBCRelation(schema, parts, jdbcOptions)(sparkSession)
+  }
 }
 
 private[sql] case class JDBCRelation(
-    parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
+    override val schema: StructType,
+    parts: Array[Partition],
+    jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
   extends BaseRelation
   with PrunedFilteredScan
   with InsertableRelation {
@@ -111,15 +170,6 @@ private[sql] case class JDBCRelation(
 
   override val needConversion: Boolean = false
 
-  override val schema: StructType = {
-    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
-    jdbcOptions.customSchema match {
-      case Some(customSchema) => JdbcUtils.getCustomSchema(
-        tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
-      case None => tableSchema
-    }
-  }
-
   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
     filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
```
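To see what `verifyAndGetNormalizedColumnName` does in isolation, here is a minimal standalone sketch; the case-insensitive resolver and the H2-style double-quote quoting are stand-ins for the session's `Resolver` and the dialect's `quoteIdentifier`:

```scala
// A minimal sketch of the matching logic, outside Spark. The resolver and
// quoting rules are stand-ins, not the real Spark implementations.
object NormalizePartitionColumn {
  type Resolver = (String, String) => Boolean

  // Stand-in for a case-insensitive session resolver.
  val caseInsensitive: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // Stand-in for dialect.quoteIdentifier (H2-style double quotes).
  def quote(name: String): String = "\"" + name + "\""

  // Find the schema field the user-supplied name refers to, matching either
  // the raw or the quoted form, and return the quoted (normalized) name.
  def normalize(schemaFields: Seq[String], columnName: String, resolver: Resolver): String =
    schemaFields.find { fieldName =>
      resolver(fieldName, columnName) || resolver(quote(fieldName), columnName)
    }.map(quote).getOrElse {
      throw new IllegalArgumentException(s"Partition column $columnName not found")
    }

  def main(args: Array[String]): Unit = {
    val fields = Seq("THEID", "THE ID")
    println(normalize(fields, "ThEiD", caseInsensitive))     // prints "THEID"
    println(normalize(fields, "\"THEID\"", caseInsensitive)) // prints "THEID"
    println(normalize(fields, "THE ID", caseInsensitive))    // prints "THE ID"
  }
}
```

Returning the dialect-quoted name means columns containing spaces, such as `THE ID` above, produce valid WHERE clauses when the partition predicates are generated.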

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider
       JDBCPartitioningInfo(
         partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
     }
-    val parts = JDBCRelation.columnPartition(partitionInfo)
-    JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
+    val resolver = sqlContext.conf.resolver
+    val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
+    val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions)
+    JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
   }
 
   override def createRelation(
```
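The provider now resolves the schema once and threads it through both `columnPartition` and `JDBCRelation`. A partitioned read through `DataFrameReader.jdbc` also ends up in `createRelation` above; a brief sketch, with placeholder URL and credentials:

```scala
// A sketch of a partitioned read that exercises createRelation above;
// the URL and credentials are placeholders.
import java.util.Properties

val props = new Properties()
props.setProperty("user", "testUser")     // hypothetical credentials
props.setProperty("password", "testPass")

val df = spark.read.jdbc(
  "jdbc:h2:mem:testdb0",  // url (hypothetical)
  "TEST.PARTITION",       // table
  "THEID",                // partition column, verified against the resolved schema
  1L,                     // lowerBound
  4L,                     // upperBound
  3,                      // numPartitions
  props)
```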

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala

Lines changed: 50 additions & 1 deletion
```diff
@@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils}
 import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite
       |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass')
     """.stripMargin.replaceAll("\n", " "))
 
+    conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " +
+      "AS SELECT 1, 1")
+      .executeUpdate()
+    conn.commit()
+
     // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
   }
 
@@ -1206,4 +1212,47 @@ class JDBCSuite extends SparkFunSuite
     }.getMessage
     assert(errMsg.contains("Statement was canceled or the session timed out"))
   }
+
+  test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") {
+    def testJdbcParitionColumn(partColName: String, expectedColumnName: String): Unit = {
+      val df = spark.read.format("jdbc")
+        .option("url", urlWithUserAndPass)
+        .option("dbtable", "TEST.PARTITION")
+        .option("partitionColumn", partColName)
+        .option("lowerBound", 1)
+        .option("upperBound", 4)
+        .option("numPartitions", 3)
+        .load()
+
+      val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName)
+      df.logicalPlan match {
+        case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
+          val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
+          assert(whereClauses === Set(
+            s"$quotedPrtColName < 2 or $quotedPrtColName is null",
+            s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3",
+            s"$quotedPrtColName >= 3"))
+      }
+    }
+
+    testJdbcParitionColumn("THEID", "THEID")
+    testJdbcParitionColumn("\"THEID\"", "THEID")
+    withSQLConf("spark.sql.caseSensitive" -> "false") {
+      testJdbcParitionColumn("ThEiD", "THEID")
+    }
+    testJdbcParitionColumn("THE ID", "THE ID")
+
+    def testIncorrectJdbcPartitionColumn(partColName: String): Unit = {
+      val errMsg = intercept[AnalysisException] {
+        testJdbcParitionColumn(partColName, "THEID")
+      }.getMessage
+      assert(errMsg.contains(s"User-defined partition column $partColName not found " +
+        "in the JDBC relation:"))
+    }
+
+    testIncorrectJdbcPartitionColumn("NoExistingColumn")
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD"))
+    }
+  }
 }
```
