 
 package org.apache.spark.sql.execution.datasources.jdbc
 
+import java.sql.{Date, Timestamp}
+
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.Partition
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType}
 import org.apache.spark.util.Utils
 
 /**
  * Instructions on how to partition the table among workers.
  */
 private[sql] case class JDBCPartitioningInfo(
     column: String,
+    columnType: DataType,
     lowerBound: Long,
     upperBound: Long,
     numPartitions: Int)
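With the new `columnType` field, the bounds are still carried as `Long`s; for date and timestamp columns they hold Catalyst's internal encoding (days since the Unix epoch, and microseconds since the Unix epoch, respectively). A minimal sketch of that encoding, assuming the spark-catalyst artifact is on the classpath (`DateTimeUtils` is a Spark-internal helper, not a stable API):

    import java.sql.{Date, Timestamp}
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    object BoundEncodingSketch extends App {
      // DateType bounds: days since 1970-01-01 (an Int), widened to Long.
      val days = DateTimeUtils.fromJavaDate(Date.valueOf("2018-01-01"))
      println(days.toLong) // 17532

      // TimestampType bounds: microseconds since the epoch, already a Long.
      // The exact value depends on the JVM's default time zone.
      val micros = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2018-01-01 00:00:00"))
      println(micros)
    }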
@@ -51,16 +55,43 @@ private[sql] object JDBCRelation extends Logging {
    * the rows with null value for the partitions column.
    *
    * @param schema resolved schema of a JDBC table
-   * @param partitioning partition information to generate the where clause for each partition
    * @param resolver function used to determine if two identifiers are equal
+   * @param timeZoneId timezone ID to be used if a partition column type is date or timestamp
    * @param jdbcOptions JDBC options that contains url
    * @return an array of partitions with where clause for each partition
    */
   def columnPartition(
       schema: StructType,
-      partitioning: JDBCPartitioningInfo,
       resolver: Resolver,
+      timeZoneId: String,
       jdbcOptions: JDBCOptions): Array[Partition] = {
+    val partitioning = {
+      import JDBCOptions._
+
+      val partitionColumn = jdbcOptions.partitionColumn
+      val lowerBound = jdbcOptions.lowerBound
+      val upperBound = jdbcOptions.upperBound
+      val numPartitions = jdbcOptions.numPartitions
+
+      if (partitionColumn.isEmpty) {
+        assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not " +
+          s"specified, '$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty")
+        null
+      } else {
+        assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
+          s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
+            s"'$JDBC_NUM_PARTITIONS' are also required")
+
+        val (column, columnType) = verifyAndGetNormalizedPartitionColumn(
+          schema, partitionColumn.get, resolver, jdbcOptions)
+
+        val lowerBoundValue = toInternalBoundValue(lowerBound.get, columnType)
+        val upperBoundValue = toInternalBoundValue(upperBound.get, columnType)
+        JDBCPartitioningInfo(
+          column, columnType, lowerBoundValue, upperBoundValue, numPartitions.get)
+      }
+    }
+
     if (partitioning == null || partitioning.numPartitions <= 1 ||
       partitioning.lowerBound == partitioning.upperBound) {
       return Array[Partition](JDBCPartition(null, 0))
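Since the partitioning info is now derived inside `columnPartition` from `JDBCOptions`, callers only pass reader options. A hypothetical read exercising this code path (URL, table, and column names are illustrative; assumes an active `SparkSession` named `spark`):

    // Date/timestamp partition columns are now accepted; the bounds are
    // parsed with java.sql.Date.valueOf / java.sql.Timestamp.valueOf.
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://localhost/testdb") // hypothetical URL
      .option("dbtable", "events")                         // hypothetical table
      .option("partitionColumn", "created_at")             // timestamp column
      .option("lowerBound", "2018-01-01 00:00:00")
      .option("upperBound", "2019-01-01 00:00:00")
      .option("numPartitions", "4")
      .load()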
@@ -72,6 +103,8 @@ private[sql] object JDBCRelation extends Logging {
       "Operation not allowed: the lower bound of partitioning column is larger than the upper " +
       s"bound. Lower bound: $lowerBound; Upper bound: $upperBound")
 
+    val boundValueToString: Long => String =
+      toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId)
     val numPartitions =
       if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */
           (upperBound - lowerBound) < 0) {
@@ -80,24 +113,25 @@ private[sql] object JDBCRelation extends Logging {
         logWarning("The number of partitions is reduced because the specified number of " +
           "partitions is less than the difference between upper bound and lower bound. " +
           s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " +
-          s"partitions: ${partitioning.numPartitions}; Lower bound: $lowerBound; " +
-          s"Upper bound: $upperBound.")
+          s"partitions: ${partitioning.numPartitions}; " +
+          s"Lower bound: ${boundValueToString(lowerBound)}; " +
+          s"Upper bound: ${boundValueToString(upperBound)}.")
         upperBound - lowerBound
       }
     // Overflow and silliness can happen if you subtract then divide.
     // Here we get a little roundoff, but that's (hopefully) OK.
     val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
 
-    val column = verifyAndGetNormalizedColumnName(
-      schema, partitioning.column, resolver, jdbcOptions)
-
     var i: Int = 0
-    var currentValue: Long = lowerBound
+    val column = partitioning.column
+    var currentValue = lowerBound
     val ans = new ArrayBuffer[Partition]()
     while (i < numPartitions) {
-      val lBound = if (i != 0) s"$column >= $currentValue" else null
+      val lBoundValue = boundValueToString(currentValue)
+      val lBound = if (i != 0) s"$column >= $lBoundValue" else null
       currentValue += stride
-      val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
+      val uBoundValue = boundValueToString(currentValue)
+      val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null
       val whereClause =
         if (uBound == null) {
           lBound
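For intuition about the stride logic above, a simplified standalone re-implementation with numeric bounds (illustrative only; the real code additionally folds `or $column is null` into the first partition's clause):

    def whereClauses(column: String, lower: Long, upper: Long, numPartitions: Int): Seq[String] = {
      // Divide before subtracting, as the real code does, to avoid overflow.
      val stride = upper / numPartitions - lower / numPartitions
      var currentValue = lower
      (0 until numPartitions).map { i =>
        val lBound = if (i != 0) Some(s"$column >= $currentValue") else None
        currentValue += stride
        val uBound = if (i != numPartitions - 1) Some(s"$column < $currentValue") else None
        (lBound ++ uBound).mkString(" AND ")
      }
    }

    // whereClauses("id", 0, 100, 4) yields:
    //   id < 25, id >= 25 AND id < 50, id >= 50 AND id < 75, id >= 75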
@@ -109,23 +143,58 @@ private[sql] object JDBCRelation extends Logging {
       ans += JDBCPartition(whereClause, i)
       i = i + 1
     }
-    ans.toArray
+    val partitions = ans.toArray
+    logInfo(s"Number of partitions: $numPartitions, WHERE clauses of these partitions: " +
+      partitions.map(_.asInstanceOf[JDBCPartition].whereClause).mkString(", "))
+    partitions
   }
 
-  // Verify column name based on the JDBC resolved schema
-  private def verifyAndGetNormalizedColumnName(
+  // Verify column name and type based on the JDBC resolved schema
+  private def verifyAndGetNormalizedPartitionColumn(
       schema: StructType,
       columnName: String,
       resolver: Resolver,
-      jdbcOptions: JDBCOptions): String = {
+      jdbcOptions: JDBCOptions): (String, DataType) = {
     val dialect = JdbcDialects.get(jdbcOptions.url)
-    schema.map(_.name).find { fieldName =>
-      resolver(fieldName, columnName) ||
-        resolver(dialect.quoteIdentifier(fieldName), columnName)
-    }.map(dialect.quoteIdentifier).getOrElse {
+    val column = schema.find { f =>
+      resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName)
+    }.getOrElse {
       throw new AnalysisException(s"User-defined partition column $columnName not " +
         s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
     }
+    column.dataType match {
+      case _: NumericType | DateType | TimestampType =>
+      case _ =>
+        throw new AnalysisException(
+          s"Partition column type should be ${NumericType.simpleString}, " +
+            s"${DateType.catalogString}, or ${TimestampType.catalogString}, but " +
+            s"${column.dataType.catalogString} found.")
+    }
+    (dialect.quoteIdentifier(column.name), column.dataType)
+  }
+
+  private def toInternalBoundValue(value: String, columnType: DataType): Long = columnType match {
+    case _: NumericType => value.toLong
+    case DateType => DateTimeUtils.fromJavaDate(Date.valueOf(value)).toLong
+    case TimestampType => DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(value))
+  }
+
+  private def toBoundValueInWhereClause(
+      value: Long,
+      columnType: DataType,
+      timeZoneId: String): String = {
+    def dateTimeToString(): String = {
+      val timeZone = DateTimeUtils.getTimeZone(timeZoneId)
+      val dateTimeStr = columnType match {
+        case DateType => DateTimeUtils.dateToString(value.toInt, timeZone)
+        case TimestampType => DateTimeUtils.timestampToString(value, timeZone)
+      }
+      s"'$dateTimeStr'"
+    }
+    columnType match {
+      case _: NumericType => value.toString
+      case DateType | TimestampType => dateTimeToString()
+    }
   }
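Finally, a small sketch of how `toBoundValueInWhereClause` renders an internal bound into a SQL literal, using only the `DateTimeUtils` calls that appear in the diff (UTC chosen for reproducibility; the column name is hypothetical):

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    val utc = DateTimeUtils.getTimeZone("UTC")
    // A DateType bound of 17532 days since the epoch becomes a quoted date literal:
    println(s"'${DateTimeUtils.dateToString(17532, utc)}'") // '2018-01-01'
    // A generated partition predicate would then read, e.g.: created_at >= '2018-01-01'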