Skip to content

Commit 971d318

Browse files
committed
[SPARK-51585][SQL] Oracle dialect supports pushdown datetime functions
### What changes were proposed in this pull request? This PR propose to make Oracle dialect supports pushdown datetime functions. ### Why are the changes needed? Currently, DS V2 pushdown framework pushed the datetime functions with in a common way. But Oracle doesn't support some datetime functions. ### Does this PR introduce _any_ user-facing change? 'No'. This is a new feature for Oracle dialect. ### How was this patch tested? GA. ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #50353 from beliefer/SPARK-51585. Authored-by: beliefer <[email protected]> Signed-off-by: beliefer <[email protected]>
1 parent 95a9689 commit 971d318

File tree

2 files changed

+186
-6
lines changed

2 files changed

+186
-6
lines changed

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala

Lines changed: 160 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,28 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
118118
"CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," +
119119
" bonus BINARY_DOUBLE)").executeUpdate()
120120
connection.prepareStatement(
121-
s"""CREATE TABLE pattern_testing_table (
122-
|pattern_testing_col VARCHAR(50)
123-
|)
124-
""".stripMargin
121+
"""CREATE TABLE pattern_testing_table (
122+
|pattern_testing_col VARCHAR(50)
123+
|)
124+
""".stripMargin
125125
).executeUpdate()
126+
connection.prepareStatement(
127+
"CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)")
128+
.executeUpdate()
129+
}
130+
131+
override def dataPreparation(connection: Connection): Unit = {
132+
super.dataPreparation(connection)
133+
connection.prepareStatement(
134+
"INSERT INTO datetime VALUES ('amy', TO_DATE('2022-05-19', 'YYYY-MM-DD')," +
135+
" TO_TIMESTAMP('2022-05-19 00:00:00', 'YYYY-MM-DD HH24:MI:SS'))").executeUpdate()
136+
connection.prepareStatement(
137+
"INSERT INTO datetime VALUES ('alex', TO_DATE('2022-05-18', 'YYYY-MM-DD')," +
138+
" TO_TIMESTAMP('2022-05-18 00:00:00', 'YYYY-MM-DD HH24:MI:SS'))").executeUpdate()
139+
// '2022-01-01' is Saturday and is in ISO year 2021.
140+
connection.prepareStatement(
141+
"INSERT INTO datetime VALUES ('tom', TO_DATE('2022-01-01', 'YYYY-MM-DD')," +
142+
" TO_TIMESTAMP('2022-01-01 00:00:00', 'YYYY-MM-DD HH24:MI:SS'))").executeUpdate()
126143
}
127144

128145
override def testUpdateColumnType(tbl: String): Unit = {
@@ -185,4 +202,143 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
185202
checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row("Eason", "Y ")))
186203
}
187204
}
205+
206+
override def testDatetime(tbl: String): Unit = {
207+
val df1 = sql(s"SELECT name FROM $tbl WHERE " +
208+
"dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ")
209+
checkFilterPushed(df1, false)
210+
val rows1 = df1.collect()
211+
assert(rows1.length === 2)
212+
assert(rows1(0).getString(0) === "amy")
213+
assert(rows1(1).getString(0) === "alex")
214+
215+
val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2")
216+
checkFilterPushed(df2, false)
217+
val rows2 = df2.collect()
218+
assert(rows2.length === 2)
219+
assert(rows2(0).getString(0) === "amy")
220+
assert(rows2(1).getString(0) === "alex")
221+
222+
val df3 = sql(s"SELECT name FROM $tbl WHERE month(date1) = 5")
223+
checkFilterPushed(df3)
224+
val rows3 = df3.collect()
225+
assert(rows3.length === 2)
226+
assert(rows3(0).getString(0) === "amy")
227+
assert(rows3(1).getString(0) === "alex")
228+
229+
val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0")
230+
checkFilterPushed(df4)
231+
val rows4 = df4.collect()
232+
assert(rows4.length === 3)
233+
assert(rows4(0).getString(0) === "amy")
234+
assert(rows4(1).getString(0) === "alex")
235+
assert(rows4(2).getString(0) === "tom")
236+
237+
val df5 = sql(s"SELECT name FROM $tbl WHERE " +
238+
"extract(WEEK from date1) > 10 AND extract(YEAR from date1) = 2022")
239+
checkFilterPushed(df5, false)
240+
val rows5 = df5.collect()
241+
assert(rows5.length === 3)
242+
assert(rows5(0).getString(0) === "amy")
243+
assert(rows5(1).getString(0) === "alex")
244+
assert(rows5(2).getString(0) === "tom")
245+
246+
val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " +
247+
"AND datediff(date1, '2022-05-10') > 0")
248+
checkFilterPushed(df6, false)
249+
val rows6 = df6.collect()
250+
assert(rows6.length === 1)
251+
assert(rows6(0).getString(0) === "amy")
252+
253+
val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2")
254+
checkFilterPushed(df7, false)
255+
val rows7 = df7.collect()
256+
assert(rows7.length === 1)
257+
assert(rows7(0).getString(0) === "alex")
258+
259+
withClue("weekofyear") {
260+
val woy = sql(s"SELECT weekofyear(date1) FROM $tbl WHERE name = 'tom'")
261+
.collect().head.getInt(0)
262+
val df = sql(s"SELECT name FROM $tbl WHERE weekofyear(date1) = $woy")
263+
checkFilterPushed(df, false)
264+
val rows = df.collect()
265+
assert(rows.length === 1)
266+
assert(rows(0).getString(0) === "tom")
267+
}
268+
269+
withClue("dayofweek") {
270+
val dow = sql(s"SELECT dayofweek(date1) FROM $tbl WHERE name = 'alex'")
271+
.collect().head.getInt(0)
272+
val df = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = $dow")
273+
checkFilterPushed(df, false)
274+
val rows = df.collect()
275+
assert(rows.length === 1)
276+
assert(rows(0).getString(0) === "alex")
277+
}
278+
279+
withClue("yearofweek") {
280+
val yow = sql(s"SELECT extract(YEAROFWEEK from date1) FROM $tbl WHERE name = 'tom'")
281+
.collect().head.getInt(0)
282+
val df = sql(s"SELECT name FROM $tbl WHERE extract(YEAROFWEEK from date1) = $yow")
283+
checkFilterPushed(df, false)
284+
val rows = df.collect()
285+
assert(rows.length === 1)
286+
assert(rows(0).getString(0) === "tom")
287+
}
288+
289+
withClue("dayofyear") {
290+
val doy = sql(s"SELECT dayofyear(date1) FROM $tbl WHERE name = 'amy'")
291+
.collect().head.getInt(0)
292+
val df = sql(s"SELECT name FROM $tbl WHERE dayofyear(date1) = $doy")
293+
checkFilterPushed(df, false)
294+
val rows = df.collect()
295+
assert(rows.length === 1)
296+
assert(rows(0).getString(0) === "amy")
297+
}
298+
299+
withClue("dayofmonth") {
300+
val dom = sql(s"SELECT dayofmonth(date1) FROM $tbl WHERE name = 'amy'")
301+
.collect().head.getInt(0)
302+
val df = sql(s"SELECT name FROM $tbl WHERE dayofmonth(date1) = $dom")
303+
checkFilterPushed(df)
304+
val rows = df.collect()
305+
assert(rows.length === 1)
306+
assert(rows(0).getString(0) === "amy")
307+
}
308+
309+
withClue("year") {
310+
val year = sql(s"SELECT year(date1) FROM $tbl WHERE name = 'amy'")
311+
.collect().head.getInt(0)
312+
val df = sql(s"SELECT name FROM $tbl WHERE year(date1) = $year")
313+
checkFilterPushed(df)
314+
val rows = df.collect()
315+
assert(rows.length === 3)
316+
assert(rows(0).getString(0) === "amy")
317+
assert(rows5(1).getString(0) === "alex")
318+
assert(rows5(2).getString(0) === "tom")
319+
}
320+
321+
withClue("second") {
322+
val df = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5")
323+
checkFilterPushed(df, false)
324+
val rows = df.collect()
325+
assert(rows.length === 2)
326+
assert(rows(0).getString(0) === "amy")
327+
assert(rows(1).getString(0) === "alex")
328+
}
329+
330+
val df9 = sql(s"SELECT name FROM $tbl WHERE " +
331+
"dayofyear(date1) > 100 order by dayofyear(date1) limit 1")
332+
checkFilterPushed(df9, false)
333+
val rows9 = df9.collect()
334+
assert(rows9.length === 1)
335+
assert(rows9(0).getString(0) === "alex")
336+
337+
val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'")
338+
checkFilterPushed(df10)
339+
val rows10 = df10.collect()
340+
assert(rows10.length === 2)
341+
assert(rows10(0).getString(0) === "amy")
342+
assert(rows10(1).getString(0) === "alex")
343+
}
188344
}

sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal
2424

2525
import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException}
2626
import org.apache.spark.sql.catalyst.SQLConfHelper
27-
import org.apache.spark.sql.connector.expressions.{Expression, Literal}
27+
import org.apache.spark.sql.connector.expressions.{Expression, Extract, Literal}
2828
import org.apache.spark.sql.errors.QueryCompilationErrors
2929
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
3030
import org.apache.spark.sql.jdbc.OracleDialect._
@@ -44,7 +44,7 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
4444
// scalastyle:on line.size.limit
4545
private val supportedAggregateFunctions =
4646
Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions
47-
private val supportedFunctions = supportedAggregateFunctions
47+
private val supportedFunctions = supportedAggregateFunctions ++ Set("TRUNC")
4848

4949
override def isSupportedFunction(funcName: String): Boolean =
5050
supportedFunctions.contains(funcName)
@@ -56,6 +56,30 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
5656

5757
class OracleSQLBuilder extends JDBCSQLBuilder {
5858

59+
override def visitExtract(extract: Extract): String = {
60+
val field = extract.field
61+
field match {
62+
// YEAR, MONTH, DAY, HOUR, MINUTE are identical on Oracle and Spark for
63+
// both datetime and interval types.
64+
case "YEAR" | "MONTH" | "DAY" | "HOUR" | "MINUTE" =>
65+
super.visitExtract(field, build(extract.source()))
66+
// Oracle does not support the following date fields: DAY_OF_YEAR, WEEK, QUARTER,
67+
// DAY_OF_WEEK, or YEAR_OF_WEEK.
68+
// We can't push down SECOND due to the difference in result types between Spark and
69+
// Oracle. Spark returns decimal(8, 6), but Oracle returns integer.
70+
case _ =>
71+
visitUnexpectedExpr(extract)
72+
}
73+
}
74+
75+
override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
76+
funcName match {
77+
case "TRUNC" =>
78+
s"TRUNC(${inputs(0)}, 'IW')"
79+
case _ => super.visitSQLFunction(funcName, inputs)
80+
}
81+
}
82+
5983
override def visitAggregateFunction(
6084
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
6185
if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {

0 commit comments

Comments
 (0)