
Commit f65e6f4

maropu authored and Robert Kruszewski committed
[SPARK-22814][SQL] Support Date/Timestamp in a JDBC partition column
## What changes were proposed in this pull request?

This PR adds support for Date/Timestamp in a JDBC partition column (master currently supports only numeric columns). It also modifies the code to verify the partition column type:

```
val jdbcTable = spark.read
  .option("partitionColumn", "text")
  .option("lowerBound", "aaa")
  .option("upperBound", "zzz")
  .option("numPartitions", 2)
  .jdbc("jdbc:postgresql:postgres", "t", options)

// with this pr
org.apache.spark.sql.AnalysisException: Partition column type should be numeric, date, or timestamp, but string found.;
  at org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation$.verifyAndGetNormalizedPartitionColumn(JDBCRelation.scala:165)
  at org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation$.columnPartition(JDBCRelation.scala:85)
  at org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider.createRelation(JdbcRelationProvider.scala:36)
  at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:317)

// without this pr
java.lang.NumberFormatException: For input string: "aaa"
  at java.lang.NumberFormatException.forInputString(NumberFormatException.java:65)
  at java.lang.Long.parseLong(Long.java:589)
  at java.lang.Long.parseLong(Long.java:631)
  at scala.collection.immutable.StringLike$class.toLong(StringLike.scala:277)
```

Closes apache#19999

## How was this patch tested?

Added tests in `JDBCSuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes apache#21834 from maropu/SPARK-22814.
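For illustration, a minimal read that exercises the new behavior might look like the following sketch. The option names and `format("jdbc")`/`load()` calls are the documented JDBC data source API (also used in the tests below); the URL, table name `t`, and date column `d` are placeholders, and a running SparkSession named `spark` is assumed.

```scala
// A DATE (or TIMESTAMP) column can now drive partitioned JDBC reads.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql:postgres")   // placeholder connection URL
  .option("dbtable", "t")                      // placeholder table with a DATE column "d"
  .option("partitionColumn", "d")
  .option("lowerBound", "2018-07-06")
  .option("upperBound", "2018-07-20")
  .option("numPartitions", 3)
  .load()
```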
1 parent d7a1ba4 commit f65e6f4

8 files changed: +258 additions, −53 deletions

docs/sql-programming-guide.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -1345,8 +1345,8 @@ the following case-insensitive options:
     These options must all be specified if any of them is specified. In addition,
     <code>numPartitions</code> must be specified. They describe how to partition the table when
     reading in parallel from multiple workers.
-    <code>partitionColumn</code> must be a numeric column from the table in question. Notice
-    that <code>lowerBound</code> and <code>upperBound</code> are just used to decide the
+    <code>partitionColumn</code> must be a numeric, date, or timestamp column from the table in question.
+    Notice that <code>lowerBound</code> and <code>upperBound</code> are just used to decide the
     partition stride, not for filtering the rows in table. So all rows in the table will be
     partitioned and returned. This option applies only to reading.
   </td>
```
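To make the "stride, not filter" point concrete, here is a simplified standalone sketch of the predicate generation (illustrative names, not Spark's internal API). The first partition also picks up NULL values and the last one is open-ended, so every row lands in some partition regardless of the bounds.

```scala
// Simplified sketch of how lowerBound/upperBound only choose partition boundaries.
def whereClauses(column: String, lower: Long, upper: Long, numPartitions: Int): Seq[String] = {
  val stride = upper / numPartitions - lower / numPartitions
  var current = lower
  (0 until numPartitions).map { i =>
    val lBound = if (i != 0) Some(s"$column >= $current") else None
    current += stride
    val uBound = if (i != numPartitions - 1) Some(s"$column < $current") else None
    (lBound, uBound) match {
      case (Some(l), Some(u)) => s"$l AND $u"
      case (None, Some(u))    => s"$u or $column is null"  // first partition keeps NULLs
      case (Some(l), None)    => l                          // last partition is open-ended
      case _                  => ""                          // single partition: no predicate
    }
  }
}

println(whereClauses("id", 0, 30, 3))
// List(id < 10 or id is null, id >= 10 AND id < 20, id >= 20)
```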

external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala

Lines changed: 80 additions & 6 deletions
```diff
@@ -17,12 +17,14 @@
 
 package org.apache.spark.sql.jdbc
 
+import java.math.BigDecimal
 import java.sql.{Connection, Date, Timestamp}
 import java.util.{Properties, TimeZone}
-import java.math.BigDecimal
 
-import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCRelation}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -86,7 +88,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
     conn.prepareStatement(
       "CREATE TABLE tableWithCustomSchema (id NUMBER, n1 NUMBER(1), n2 NUMBER(1))").executeUpdate()
     conn.prepareStatement(
-      "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)").executeUpdate()
+      "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)")
+      .executeUpdate()
     conn.commit()
 
     sql(
@@ -108,15 +111,36 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
       """.stripMargin.replaceAll("\n", " "))
 
 
-    conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate()
+    conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))")
+      .executeUpdate()
     conn.prepareStatement(
       "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate()
     conn.commit()
 
-    conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate()
+    conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)")
+      .executeUpdate()
     conn.commit()
-  }
 
+    conn.prepareStatement("CREATE TABLE datetimePartitionTest (id NUMBER(10), d DATE, t TIMESTAMP)")
+      .executeUpdate()
+    conn.prepareStatement(
+      """INSERT INTO datetimePartitionTest VALUES
+        |(1, {d '2018-07-06'}, {ts '2018-07-06 05:50:00'})
+      """.stripMargin.replaceAll("\n", " ")).executeUpdate()
+    conn.prepareStatement(
+      """INSERT INTO datetimePartitionTest VALUES
+        |(2, {d '2018-07-06'}, {ts '2018-07-06 08:10:08'})
+      """.stripMargin.replaceAll("\n", " ")).executeUpdate()
+    conn.prepareStatement(
+      """INSERT INTO datetimePartitionTest VALUES
+        |(3, {d '2018-07-08'}, {ts '2018-07-08 13:32:01'})
+      """.stripMargin.replaceAll("\n", " ")).executeUpdate()
+    conn.prepareStatement(
+      """INSERT INTO datetimePartitionTest VALUES
+        |(4, {d '2018-07-12'}, {ts '2018-07-12 09:51:15'})
+      """.stripMargin.replaceAll("\n", " ")).executeUpdate()
+    conn.commit()
+  }
 
   test("SPARK-16625 : Importing Oracle numeric types") {
     val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties)
@@ -399,4 +423,54 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
     assert(values.getDouble(0) === 1.1)
     assert(values.getFloat(1) === 2.2f)
   }
+
+  test("SPARK-22814 support date/timestamp types in partitionColumn") {
+    val expectedResult = Set(
+      (1, "2018-07-06", "2018-07-06 05:50:00"),
+      (2, "2018-07-06", "2018-07-06 08:10:08"),
+      (3, "2018-07-08", "2018-07-08 13:32:01"),
+      (4, "2018-07-12", "2018-07-12 09:51:15")
+    ).map { case (id, date, timestamp) =>
+      Row(BigDecimal.valueOf(id), Date.valueOf(date), Timestamp.valueOf(timestamp))
+    }
+
+    // DateType partition column
+    val df1 = spark.read.format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", "datetimePartitionTest")
+      .option("partitionColumn", "d")
+      .option("lowerBound", "2018-07-06")
+      .option("upperBound", "2018-07-20")
+      .option("numPartitions", 3)
+      .load()
+
+    df1.logicalPlan match {
+      case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
+        val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
+        assert(whereClauses === Set(
+          """"D" < '2018-07-10' or "D" is null""",
+          """"D" >= '2018-07-10' AND "D" < '2018-07-14'""",
+          """"D" >= '2018-07-14'"""))
+    }
+    assert(df1.collect.toSet === expectedResult)
+
+    // TimestampType partition column
+    val df2 = spark.read.format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", "datetimePartitionTest")
+      .option("partitionColumn", "t")
+      .option("lowerBound", "2018-07-04 03:30:00.0")
+      .option("upperBound", "2018-07-27 14:11:05.0")
+      .option("numPartitions", 2)
+      .load()
+
+    df2.logicalPlan match {
+      case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
+        val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
+        assert(whereClauses === Set(
+          """"T" < '2018-07-15 20:50:32.5' or "T" is null""",
+          """"T" >= '2018-07-15 20:50:32.5'"""))
+    }
+    assert(df2.collect.toSet === expectedResult)
+  }
 }
```
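As a side note, the boundary dates asserted above ('2018-07-10' and '2018-07-14') follow from the integer-stride arithmetic in `columnPartition`, assuming dates are carried internally as days since the Unix epoch. The epoch-day values below are computed here for illustration, not taken from the source:

```scala
import java.time.LocalDate

val lower = LocalDate.parse("2018-07-06").toEpochDay   // 17718
val upper = LocalDate.parse("2018-07-20").toEpochDay   // 17732

// Stride as in columnPartition: upper / n - lower / n, with n = 3 partitions.
val stride = upper / 3 - lower / 3                      // 5910 - 5906 = 4 days

println(LocalDate.ofEpochDay(lower + stride))           // 2018-07-10 -> first boundary
println(LocalDate.ofEpochDay(lower + 2 * stride))       // 2018-07-14 -> second boundary
```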

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 7 additions & 3 deletions
```diff
@@ -96,9 +96,9 @@ object DateTimeUtils {
     }
   }
 
-  def getThreadLocalDateFormat(): DateFormat = {
+  def getThreadLocalDateFormat(timeZone: TimeZone): DateFormat = {
     val sdf = threadLocalDateFormat.get()
-    sdf.setTimeZone(defaultTimeZone())
+    sdf.setTimeZone(timeZone)
     sdf
   }
 
@@ -144,7 +144,11 @@
   }
 
   def dateToString(days: SQLDate): String =
-    getThreadLocalDateFormat.format(toJavaDate(days))
+    getThreadLocalDateFormat(defaultTimeZone()).format(toJavaDate(days))
+
+  def dateToString(days: SQLDate, timeZone: TimeZone): String = {
+    getThreadLocalDateFormat(timeZone).format(toJavaDate(days))
+  }
 
   // Converts Timestamp to string according to Hive TimestampWritable convention.
   def timestampToString(us: SQLTimestamp): String = {
```
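The new `timeZone` parameter matters because the same instant can fall on different calendar days in different zones, so the rendered bound literal depends on the session time zone. A small standalone illustration using plain `java.text` (not Spark's `DateTimeUtils`):

```scala
import java.text.SimpleDateFormat
import java.util.{Date, TimeZone}

// Epoch day 17718 corresponds to the instant 2018-07-06T00:00:00Z.
val instant = new Date(17718L * 24 * 60 * 60 * 1000)
val fmt = new SimpleDateFormat("yyyy-MM-dd")

fmt.setTimeZone(TimeZone.getTimeZone("UTC"))
println(fmt.format(instant))                          // 2018-07-06

fmt.setTimeZone(TimeZone.getTimeZone("America/Los_Angeles"))
println(fmt.format(instant))                          // 2018-07-05 -- same instant, earlier calendar day
```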

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -410,7 +410,7 @@ object PartitioningUtils {
     val dateTry = Try {
       // try and parse the date, if no exception occurs this is a candidate to be resolved as
       // DateType
-      DateTimeUtils.getThreadLocalDateFormat.parse(raw)
+      DateTimeUtils.getThreadLocalDateFormat(DateTimeUtils.defaultTimeZone()).parse(raw)
       // SPARK-23436: Casting the string to date may still return null if a bad Date is provided.
       // This can happen since DateFormat.parse may not use the entire text of the given string:
       // so if there are extra-characters after the date, it returns correctly.
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -119,9 +119,9 @@ class JDBCOptions(
   // the column used to partition
   val partitionColumn = parameters.get(JDBC_PARTITION_COLUMN)
   // the lower bound of partition column
-  val lowerBound = parameters.get(JDBC_LOWER_BOUND).map(_.toLong)
+  val lowerBound = parameters.get(JDBC_LOWER_BOUND)
   // the upper bound of the partition column
-  val upperBound = parameters.get(JDBC_UPPER_BOUND).map(_.toLong)
+  val upperBound = parameters.get(JDBC_UPPER_BOUND)
   // numPartitions is also used for data source writing
   require((partitionColumn.isEmpty && lowerBound.isEmpty && upperBound.isEmpty) ||
     (partitionColumn.isDefined && lowerBound.isDefined && upperBound.isDefined &&
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 88 additions & 19 deletions
```diff
@@ -17,23 +17,27 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
+import java.sql.{Date, Timestamp}
+
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.Partition
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType}
 import org.apache.spark.util.Utils
 
 /**
  * Instructions on how to partition the table among workers.
  */
 private[sql] case class JDBCPartitioningInfo(
     column: String,
+    columnType: DataType,
     lowerBound: Long,
     upperBound: Long,
     numPartitions: Int)
@@ -51,16 +55,43 @@ private[sql] object JDBCRelation extends Logging {
   * the rows with null value for the partitions column.
   *
   * @param schema resolved schema of a JDBC table
-  * @param partitioning partition information to generate the where clause for each partition
   * @param resolver function used to determine if two identifiers are equal
+  * @param timeZoneId timezone ID to be used if a partition column type is date or timestamp
   * @param jdbcOptions JDBC options that contains url
   * @return an array of partitions with where clause for each partition
   */
  def columnPartition(
      schema: StructType,
-      partitioning: JDBCPartitioningInfo,
      resolver: Resolver,
+      timeZoneId: String,
      jdbcOptions: JDBCOptions): Array[Partition] = {
+    val partitioning = {
+      import JDBCOptions._
+
+      val partitionColumn = jdbcOptions.partitionColumn
+      val lowerBound = jdbcOptions.lowerBound
+      val upperBound = jdbcOptions.upperBound
+      val numPartitions = jdbcOptions.numPartitions
+
+      if (partitionColumn.isEmpty) {
+        assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not " +
+          s"specified, '$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty")
+        null
+      } else {
+        assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
+          s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
+          s"'$JDBC_NUM_PARTITIONS' are also required")
+
+        val (column, columnType) = verifyAndGetNormalizedPartitionColumn(
+          schema, partitionColumn.get, resolver, jdbcOptions)
+
+        val lowerBoundValue = toInternalBoundValue(lowerBound.get, columnType)
+        val upperBoundValue = toInternalBoundValue(upperBound.get, columnType)
+        JDBCPartitioningInfo(
+          column, columnType, lowerBoundValue, upperBoundValue, numPartitions.get)
+      }
+    }
+
     if (partitioning == null || partitioning.numPartitions <= 1 ||
       partitioning.lowerBound == partitioning.upperBound) {
       return Array[Partition](JDBCPartition(null, 0))
@@ -72,6 +103,8 @@
         "Operation not allowed: the lower bound of partitioning column is larger than the upper " +
         s"bound. Lower bound: $lowerBound; Upper bound: $upperBound")
 
+    val boundValueToString: Long => String =
+      toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId)
     val numPartitions =
       if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */
           (upperBound - lowerBound) < 0) {
@@ -80,24 +113,25 @@
         logWarning("The number of partitions is reduced because the specified number of " +
           "partitions is less than the difference between upper bound and lower bound. " +
           s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " +
-          s"partitions: ${partitioning.numPartitions}; Lower bound: $lowerBound; " +
-          s"Upper bound: $upperBound.")
+          s"partitions: ${partitioning.numPartitions}; " +
+          s"Lower bound: ${boundValueToString(lowerBound)}; " +
+          s"Upper bound: ${boundValueToString(upperBound)}.")
         upperBound - lowerBound
       }
     // Overflow and silliness can happen if you subtract then divide.
     // Here we get a little roundoff, but that's (hopefully) OK.
     val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
 
-    val column = verifyAndGetNormalizedColumnName(
-      schema, partitioning.column, resolver, jdbcOptions)
-
     var i: Int = 0
-    var currentValue: Long = lowerBound
+    val column = partitioning.column
+    var currentValue = lowerBound
     val ans = new ArrayBuffer[Partition]()
     while (i < numPartitions) {
-      val lBound = if (i != 0) s"$column >= $currentValue" else null
+      val lBoundValue = boundValueToString(currentValue)
+      val lBound = if (i != 0) s"$column >= $lBoundValue" else null
       currentValue += stride
-      val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
+      val uBoundValue = boundValueToString(currentValue)
+      val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null
       val whereClause =
         if (uBound == null) {
           lBound
@@ -109,23 +143,58 @@
       ans += JDBCPartition(whereClause, i)
       i = i + 1
     }
-    ans.toArray
+    val partitions = ans.toArray
+    logInfo(s"Number of partitions: $numPartitions, WHERE clauses of these partitions: " +
+      partitions.map(_.asInstanceOf[JDBCPartition].whereClause).mkString(", "))
+    partitions
   }
 
-  // Verify column name based on the JDBC resolved schema
-  private def verifyAndGetNormalizedColumnName(
+  // Verify column name and type based on the JDBC resolved schema
+  private def verifyAndGetNormalizedPartitionColumn(
      schema: StructType,
      columnName: String,
      resolver: Resolver,
-      jdbcOptions: JDBCOptions): String = {
+      jdbcOptions: JDBCOptions): (String, DataType) = {
     val dialect = JdbcDialects.get(jdbcOptions.url)
-    schema.map(_.name).find { fieldName =>
-      resolver(fieldName, columnName) ||
-        resolver(dialect.quoteIdentifier(fieldName), columnName)
-    }.map(dialect.quoteIdentifier).getOrElse {
+    val column = schema.find { f =>
+      resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName)
+    }.getOrElse {
       throw new AnalysisException(s"User-defined partition column $columnName not " +
         s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
     }
+    column.dataType match {
+      case _: NumericType | DateType | TimestampType =>
+      case _ =>
+        throw new AnalysisException(
+          s"Partition column type should be ${NumericType.simpleString}, " +
+          s"${DateType.catalogString}, or ${TimestampType.catalogString}, but " +
+          s"${column.dataType.catalogString} found.")
+    }
+    (dialect.quoteIdentifier(column.name), column.dataType)
+  }
+
+  private def toInternalBoundValue(value: String, columnType: DataType): Long = columnType match {
+    case _: NumericType => value.toLong
+    case DateType => DateTimeUtils.fromJavaDate(Date.valueOf(value)).toLong
+    case TimestampType => DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(value))
+  }
+
+  private def toBoundValueInWhereClause(
+      value: Long,
+      columnType: DataType,
+      timeZoneId: String): String = {
+    def dateTimeToString(): String = {
+      val timeZone = DateTimeUtils.getTimeZone(timeZoneId)
+      val dateTimeStr = columnType match {
+        case DateType => DateTimeUtils.dateToString(value.toInt, timeZone)
+        case TimestampType => DateTimeUtils.timestampToString(value, timeZone)
+      }
+      s"'$dateTimeStr'"
+    }
+    columnType match {
+      case _: NumericType => value.toString
+      case DateType | TimestampType => dateTimeToString()
+    }
  }
 
  /**
```
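Putting the two new helpers together: a bound string is parsed into an internal `Long` once the column type is known, and later rendered back as a SQL literal for the WHERE clause. The following is a rough standalone sketch of that round trip, under the assumed semantics that dates become days since the epoch and timestamps microseconds since the epoch; it uses the JVM default time zone and a plain `String` in place of Spark's `DataType`, so it is an illustration, not the actual implementation.

```scala
import java.sql.{Date, Timestamp}
import java.time.LocalDate
import java.util.concurrent.TimeUnit

// Stand-in for toInternalBoundValue: parse the bound according to the column type.
def toInternalBound(value: String, columnType: String): Long = columnType match {
  case "numeric"   => value.toLong
  case "date"      => Date.valueOf(value).toLocalDate.toEpochDay
  case "timestamp" => TimeUnit.MILLISECONDS.toMicros(Timestamp.valueOf(value).getTime)
}

// Stand-in for toBoundValueInWhereClause: render the internal value as a SQL literal.
def toWhereClauseLiteral(value: Long, columnType: String): String = columnType match {
  case "numeric"   => value.toString
  case "date"      => s"'${Date.valueOf(LocalDate.ofEpochDay(value))}'"
  case "timestamp" => s"'${new Timestamp(TimeUnit.MICROSECONDS.toMillis(value))}'"
}

println(toWhereClauseLiteral(toInternalBound("2018-07-06", "date"), "date"))  // '2018-07-06'
println(toWhereClauseLiteral(toInternalBound("42", "numeric"), "numeric"))    // 42
```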

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala

Lines changed: 2 additions & 19 deletions
```diff
@@ -29,28 +29,11 @@ class JdbcRelationProvider extends CreatableRelationProvider
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
-    import JDBCOptions._
-
     val jdbcOptions = new JDBCOptions(parameters)
-    val partitionColumn = jdbcOptions.partitionColumn
-    val lowerBound = jdbcOptions.lowerBound
-    val upperBound = jdbcOptions.upperBound
-    val numPartitions = jdbcOptions.numPartitions
-
-    val partitionInfo = if (partitionColumn.isEmpty) {
-      assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not specified, " +
-        s"'$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty")
-      null
-    } else {
-      assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
-        s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
-        s"'$JDBC_NUM_PARTITIONS' are also required")
-      JDBCPartitioningInfo(
-        partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
-    }
     val resolver = sqlContext.conf.resolver
+    val timeZoneId = sqlContext.conf.sessionLocalTimeZone
     val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
-    val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions)
+    val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
     JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
   }
 
```