
Commit 9c50156

uros-db authored and cloud-fan committed
[SPARK-53108][SQL] Implement the time_diff function in Scala
### What changes were proposed in this pull request?

Implement the `time_diff` function in the Scala API.

### Why are the changes needed?

Expand API support for the `TimeDiff` expression.

### Does this PR introduce _any_ user-facing change?

Yes, the new function is now available in the Scala API.

### How was this patch tested?

Added appropriate Scala function tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51826 from uros-db/scala-time_diff.

Authored-by: Uros Bojanic <uros.bojanic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent f93eff3 commit 9c50156
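
For illustration, a minimal spark-shell sketch of the new call. The setup is not part of the commit: it assumes a Spark 4.1.0 build with TIME type support and a `spark` session in scope, and the whole-unit truncation shown matches the test expectations further down.

import java.time.LocalTime
import org.apache.spark.sql.functions.{col, lit, time_diff}
import spark.implicits._

val df = Seq(
  (LocalTime.parse("20:30:29"), LocalTime.parse("21:30:28"))
).toDF("start", "end")

// The unit is passed as a Column (e.g. lit("MINUTE")), per the new signature.
df.select(time_diff(lit("MINUTE"), col("start"), col("end"))).show()
// Expected: 59, since the two times are 59 min 59 s apart and the
// difference is truncated to whole minutes.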

File tree

3 files changed: +95 -1 lines changed


python/pyspark/sql/tests/test_functions.py

Lines changed: 4 additions & 1 deletion

@@ -81,7 +81,10 @@ def test_function_parity(self):
         missing_in_py = jvm_fn_set.difference(py_fn_set)
 
         # Functions that we expect to be missing in python until they are added to pyspark
-        expected_missing_in_py = set()
+        expected_missing_in_py = set(
+            # TODO(SPARK-53108): Implement the time_diff function in Python
+            ["time_diff"]
+        )
 
         self.assertEqual(
             expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected"

sql/api/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 22 additions & 0 deletions

@@ -6292,6 +6292,28 @@ object functions {
   def timestamp_add(unit: String, quantity: Column, ts: Column): Column =
     Column.internalFn("timestampadd", lit(unit), quantity, ts)
 
+  /**
+   * Returns the difference between two times, measured in specified units. Throws a
+   * SparkIllegalArgumentException, in case the specified unit is not supported.
+   *
+   * @param unit
+   *   A STRING representing the unit of the time difference. Supported units are: "HOUR",
+   *   "MINUTE", "SECOND", "MILLISECOND", and "MICROSECOND". The unit is case-insensitive.
+   * @param start
+   *   A starting TIME.
+   * @param end
+   *   An ending TIME.
+   * @return
+   *   The difference between `end` and `start` times, measured in specified units.
+   * @note
+   *   If any of the inputs is `NULL`, the result is `NULL`.
+   * @group datetime_funcs
+   * @since 4.1.0
+   */
+  def time_diff(unit: Column, start: Column, end: Column): Column = {
+    Column.fn("time_diff", unit, start, end)
+  }
+
   /**
    * Returns `time` truncated to the `unit`.
    *
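
Since the wrapper is a thin `Column.fn("time_diff", ...)` call, the Scala spelling and the SQL spelling should resolve to the same `TimeDiff` expression. A sketch of that equivalence, mirroring the test data below (session setup assumed, not part of the commit):

import java.time.LocalTime
import org.apache.spark.sql.functions.{col, lit, time_diff}
import spark.implicits._  // assumes a `spark` session, as in spark-shell

val df = Seq(
  (LocalTime.parse("09:32:05.359123"), LocalTime.parse("17:23:49.906152"))
).toDF("start", "end")

val viaSql   = df.selectExpr("time_diff('SECOND', start, end)")              // SQL form
val viaScala = df.select(time_diff(lit("SECOND"), col("start"), col("end"))) // Scala form
// Both should return 28304 (7 h 51 m 44.547029 s, truncated to whole seconds).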

sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala

Lines changed: 69 additions & 0 deletions

@@ -290,6 +290,75 @@ abstract class TimeFunctionsSuiteBase extends QueryTest with SharedSparkSession
     checkAnswer(result2, expected)
   }
 
+  test("SPARK-53108: time_diff function") {
+    // Input data for the function.
+    val schema = StructType(Seq(
+      StructField("unit", StringType, nullable = false),
+      StructField("start", TimeType(), nullable = false),
+      StructField("end", TimeType(), nullable = false)
+    ))
+    val data = Seq(
+      Row("HOUR", LocalTime.parse("20:30:29"), LocalTime.parse("21:30:28")),
+      Row("second", LocalTime.parse("09:32:05.359123"), LocalTime.parse("17:23:49.906152")),
+      Row("MicroSecond", LocalTime.parse("09:32:05.359123"), LocalTime.parse("17:23:49.906152"))
+    )
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+
+    // Test the function using both `selectExpr` and `select`.
+    val result1 = df.selectExpr(
+      "time_diff(unit, start, end)"
+    )
+    val result2 = df.select(
+      time_diff(col("unit"), col("start"), col("end"))
+    )
+    // Check that both methods produce the same result.
+    checkAnswer(result1, result2)
+
+    // Expected output of the function.
+    val expected = Seq(
+      0,
+      28304,
+      28304547029L
+    ).toDF("diff").select(col("diff"))
+    // Check that the results match the expected output.
+    checkAnswer(result1, expected)
+    checkAnswer(result2, expected)
+
+    // NULL result is returned for any NULL input.
+    val nullInputDF = Seq(
+      (null, LocalTime.parse("01:02:03"), LocalTime.parse("01:02:03")),
+      ("HOUR", null, LocalTime.parse("01:02:03")),
+      ("HOUR", LocalTime.parse("01:02:03"), null),
+      ("HOUR", null, null),
+      (null, LocalTime.parse("01:02:03"), null),
+      (null, null, LocalTime.parse("01:02:03")),
+      (null, null, null)
+    ).toDF("unit", "start", "end")
+    val nullResult = Seq[Integer](
+      null, null, null, null, null, null, null
+    ).toDF("diff").select(col("diff"))
+    checkAnswer(
+      nullInputDF.select(time_diff(col("unit"), col("start"), col("end"))),
+      nullResult
+    )
+
+    // Error is thrown for malformed input.
+    val invalidUnitDF = Seq(
+      ("invalid_unit", LocalTime.parse("01:02:03"), LocalTime.parse("01:02:03"))
+    ).toDF("unit", "start", "end")
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        invalidUnitDF.select(time_diff(col("unit"), col("start"), col("end"))).collect()
+      },
+      condition = "INVALID_PARAMETER_VALUE.TIME_UNIT",
+      parameters = Map(
+        "functionName" -> "`time_diff`",
+        "parameter" -> "`unit`",
+        "invalidValue" -> "'invalid_unit'"
+      )
+    )
+  }
+
   test("SPARK-53107: time_trunc function") {
     // Input data for the function (including null values).
     val schema = StructType(Seq(
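
The expected values in this test can be sanity-checked with plain java.time arithmetic, which truncates to whole units the same way the test expects (an independent cross-check, not part of the commit):

import java.time.LocalTime
import java.time.temporal.ChronoUnit

// HOUR row: 21:30:28 minus 20:30:29 is 59 min 59 s, less than a full hour.
ChronoUnit.HOURS.between(
  LocalTime.parse("20:30:29"), LocalTime.parse("21:30:28"))  // 0

// second / MicroSecond rows: the gap is 7 h 51 m 44.547029 s.
val start = LocalTime.parse("09:32:05.359123")
val end   = LocalTime.parse("17:23:49.906152")
ChronoUnit.SECONDS.between(start, end)  // 28304
ChronoUnit.MICROS.between(start, end)   // 28304547029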
