Skip to content

Commit ea8b6fd

Browse files
uros-dbcloud-fan
authored andcommitted
[SPARK-53107][SQL] Implement the time_trunc function in Scala
### What changes were proposed in this pull request? Implement the `time_trunc` function in Scala API. ### Why are the changes needed? Expand API support for the `TimeTrunc` expression. ### Does this PR introduce _any_ user-facing change? Yes, the new function is now available in Scala API. ### How was this patch tested? Added appropriate Scala function tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51823 from uros-db/scala-time_trunc. Authored-by: Uros Bojanic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 70c2008 commit ea8b6fd

File tree

3 files changed

+80
-2
lines changed

3 files changed

+80
-2
lines changed

python/pyspark/sql/tests/test_functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ def test_function_parity(self):
8181
missing_in_py = jvm_fn_set.difference(py_fn_set)
8282

8383
# Functions that we expect to be missing in python until they are added to pyspark
84-
expected_missing_in_py = set()
84+
expected_missing_in_py = set(
85+
# TODO(SPARK-53107): Implement the time_trunc function in Python
86+
["time_trunc"]
87+
)
8588

8689
self.assertEqual(
8790
expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected"

sql/api/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6292,6 +6292,27 @@ object functions {
62926292
def timestamp_add(unit: String, quantity: Column, ts: Column): Column =
62936293
Column.internalFn("timestampadd", lit(unit), quantity, ts)
62946294

6295+
/**
6296+
* Returns `time` truncated to the `unit`.
6297+
*
6298+
* @param unit
6299+
* A STRING representing the unit to truncate the time to. Supported units are: "HOUR",
6300+
* "MINUTE", "SECOND", "MILLISECOND", and "MICROSECOND". The unit is case-insensitive.
6301+
* @param time
6302+
* A TIME to truncate.
6303+
* @return
6304+
* A TIME truncated to the specified unit.
6305+
* @note
6306+
* If any of the inputs is `NULL`, the result is `NULL`.
6307+
* @throws IllegalArgumentException
6308+
* If the `unit` is not supported.
6309+
* @group datetime_funcs
6310+
* @since 4.1.0
6311+
*/
6312+
def time_trunc(unit: Column, time: Column): Column = {
6313+
Column.fn("time_trunc", unit, time)
6314+
}
6315+
62956316
/**
62966317
* Parses the `timestamp` expression with the `format` expression to a timestamp without time
62976318
* zone. Returns null with invalid input.

sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql
2020
import java.time.LocalTime
2121
import java.time.temporal.ChronoUnit
2222

23-
import org.apache.spark.{SparkConf, SparkDateTimeException}
23+
import org.apache.spark.{SparkConf, SparkDateTimeException, SparkIllegalArgumentException}
2424
import org.apache.spark.sql.functions._
2525
import org.apache.spark.sql.internal.SQLConf
2626
import org.apache.spark.sql.test.SharedSparkSession
@@ -241,6 +241,60 @@ abstract class TimeFunctionsSuiteBase extends QueryTest with SharedSparkSession
241241
checkAnswer(result2, expected)
242242
}
243243

244+
test("SPARK-53107: time_trunc function") {
245+
// Input data for the function (including null values).
246+
val schema = StructType(Seq(
247+
StructField("unit", StringType),
248+
StructField("time", TimeType())
249+
))
250+
val data = Seq(
251+
Row("HOUR", LocalTime.parse("00:00:00")),
252+
Row("second", LocalTime.parse("01:02:03.4")),
253+
Row("MicroSecond", LocalTime.parse("23:59:59.999999")),
254+
Row(null, LocalTime.parse("01:02:03")),
255+
Row("MiNuTe", null),
256+
Row(null, null)
257+
)
258+
val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
259+
260+
// Test the function using both `selectExpr` and `select`.
261+
val result1 = df.selectExpr(
262+
"time_trunc(unit, time)"
263+
)
264+
val result2 = df.select(
265+
time_trunc(col("unit"), col("time"))
266+
)
267+
// Check that both methods produce the same result.
268+
checkAnswer(result1, result2)
269+
270+
// Expected output of the function.
271+
val expected = Seq(
272+
"00:00:00",
273+
"01:02:03",
274+
"23:59:59.999999",
275+
null,
276+
null,
277+
null
278+
).toDF("timeString").select(col("timeString").cast("time"))
279+
// Check that the results match the expected output.
280+
checkAnswer(result1, expected)
281+
checkAnswer(result2, expected)
282+
283+
// Error is thrown for malformed input.
284+
val invalidUnitDF = Seq(("invalid_unit", LocalTime.parse("01:02:03"))).toDF("unit", "time")
285+
checkError(
286+
exception = intercept[SparkIllegalArgumentException] {
287+
invalidUnitDF.select(time_trunc(col("unit"), col("time"))).collect()
288+
},
289+
condition = "INVALID_PARAMETER_VALUE.TIME_UNIT",
290+
parameters = Map(
291+
"functionName" -> "`time_trunc`",
292+
"parameter" -> "`unit`",
293+
"invalidValue" -> "'invalid_unit'"
294+
)
295+
)
296+
}
297+
244298
test("SPARK-52883: to_time function without format") {
245299
// Input data for the function.
246300
val schema = StructType(Seq(

0 commit comments

Comments
 (0)