@@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.Partition
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * Instructions on how to partition the table among workers.
@@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging {
    * Null value predicate is added to the first partition where clause to include
    * the rows with null value for the partitions column.
    *
+   * @param schema resolved schema of a JDBC table
    * @param partitioning partition information to generate the where clause for each partition
+   * @param resolver function used to determine if two identifiers are equal
+   * @param jdbcOptions JDBC options that contains url
    * @return an array of partitions with where clause for each partition
    */
-  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
+  def columnPartition(
+      schema: StructType,
+      partitioning: JDBCPartitioningInfo,
+      resolver: Resolver,
+      jdbcOptions: JDBCOptions): Array[Partition] = {
     if (partitioning == null || partitioning.numPartitions <= 1 ||
       partitioning.lowerBound == partitioning.upperBound) {
       return Array[Partition](JDBCPartition(null, 0))
@@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging {
     // Overflow and silliness can happen if you subtract then divide.
     // Here we get a little roundoff, but that's (hopefully) OK.
     val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
-    val column = partitioning.column
+
+    val column = verifyAndGetNormalizedColumnName(
+      schema, partitioning.column, resolver, jdbcOptions)
+
     var i: Int = 0
     var currentValue: Long = lowerBound
     val ans = new ArrayBuffer[Partition]()
@@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging {
     }
     ans.toArray
   }
+
+  // Verify column name based on the JDBC resolved schema
+  private def verifyAndGetNormalizedColumnName(
+      schema: StructType,
+      columnName: String,
+      resolver: Resolver,
+      jdbcOptions: JDBCOptions): String = {
+    val dialect = JdbcDialects.get(jdbcOptions.url)
+    schema.map(_.name).find { fieldName =>
+      resolver(fieldName, columnName) ||
+        resolver(dialect.quoteIdentifier(fieldName), columnName)
+    }.map(dialect.quoteIdentifier).getOrElse {
+      throw new AnalysisException(s"User-defined partition column $columnName not " +
+        s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
+    }
+  }
+
+  /**
+   * Takes a (schema, table) specification and returns the table's Catalyst schema.
+   * If `customSchema` defined in the JDBC options, replaces the schema's dataType with the
+   * custom schema's type.
+   *
+   * @param resolver function used to determine if two identifiers are equal
+   * @param jdbcOptions JDBC options that contains url, table and other information.
+   * @return resolved Catalyst schema of a JDBC table
+   */
+  def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
+    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
+    jdbcOptions.customSchema match {
+      case Some(customSchema) => JdbcUtils.getCustomSchema(
+        tableSchema, customSchema, resolver)
+      case None => tableSchema
+    }
+  }
+
+  /**
+   * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema.
+   */
+  def apply(
+      parts: Array[Partition],
+      jdbcOptions: JDBCOptions)(
+      sparkSession: SparkSession): JDBCRelation = {
+    val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions)
+    JDBCRelation(schema, parts, jdbcOptions)(sparkSession)
+  }
 }
 
 private[sql] case class JDBCRelation(
-    parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
+    override val schema: StructType,
+    parts: Array[Partition],
+    jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
   extends BaseRelation
   with PrunedFilteredScan
   with InsertableRelation {
@@ -111,15 +170,6 @@ private[sql] case class JDBCRelation(
 
   override val needConversion: Boolean = false
 
-  override val schema: StructType = {
-    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
-    jdbcOptions.customSchema match {
-      case Some(customSchema) => JdbcUtils.getCustomSchema(
-        tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
-      case None => tableSchema
-    }
-  }
-
   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
     filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
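For reference, a minimal standalone sketch of the column-name verification the diff introduces. Everything below (ColumnNormalizationSketch, normalize, the naive quoteIdentifier, the case-insensitive resolver) is a hypothetical stand-in, not Spark's internal API: the real code delegates quoting to JdbcDialect.quoteIdentifier and takes the resolver from the session configuration.

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Simplified illustration of matching a user-supplied partition column against
// the resolved JDBC schema and returning it in quoted, schema-cased form.
object ColumnNormalizationSketch {
  // A resolver is just an equality predicate over identifiers; Spark's default is case-insensitive.
  type Resolver = (String, String) => Boolean
  val caseInsensitiveResolution: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // Naive stand-in for dialect-specific identifier quoting.
  def quoteIdentifier(colName: String): String = "\"" + colName + "\""

  // Accept the user input either bare or already quoted; fail if no schema field matches.
  def normalize(schema: StructType, columnName: String, resolver: Resolver): String = {
    schema.map(_.name).find { fieldName =>
      resolver(fieldName, columnName) || resolver(quoteIdentifier(fieldName), columnName)
    }.map(quoteIdentifier).getOrElse {
      throw new IllegalArgumentException(
        s"Partition column $columnName not found in ${schema.simpleString}")
    }
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("ID", IntegerType), StructField("Name", StringType)))
    println(normalize(schema, "id", caseInsensitiveResolution))     // prints "ID"
    println(normalize(schema, "\"ID\"", caseInsensitiveResolution)) // quoted input also resolves, prints "ID"
  }
}

Because the matched field name (not the raw user input) is what gets quoted and returned, the partition WHERE clauses are generated with the schema's exact casing even when the user-supplied option differs in case or quoting.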