Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 6e36d8d

Browse files
youngbink authored and gatorsmile committed
[SPARK-22829] Add new built-in function date_trunc()
## What changes were proposed in this pull request? Adding date_trunc() as a built-in function. `date_trunc` is common in other databases, but neither Spark nor Hive supports it. `date_trunc` is commonly used by data scientists and business intelligence applications such as Superset (https://github.com/apache/incubator-superset). We do have `trunc`, but it only works at the 'MONTH' and 'YEAR' levels on DateType input. date_trunc() in other databases: AWS Redshift: http://docs.aws.amazon.com/redshift/latest/dg/r_DATE_TRUNC.html PostgreSQL: https://www.postgresql.org/docs/9.1/static/functions-datetime.html Presto: https://prestodb.io/docs/current/functions/datetime.html ## How was this patch tested? Unit tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Youngbin Kim <[email protected]> Closes apache#20015 from youngbink/date_trunc.
1 parent 3a7494d commit 6e36d8d

File tree

8 files changed

+445
-52
lines changed

8 files changed

+445
-52
lines changed

python/pyspark/sql/functions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ def trunc(date, format):
10991099
"""
11001100
Returns date truncated to the unit specified by the format.
11011101
1102-
:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
1102+
:param format: 'year', 'yyyy', 'yy' or 'month', 'mon', 'mm'
11031103
11041104
>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
11051105
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
@@ -1111,6 +1111,24 @@ def trunc(date, format):
11111111
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
11121112

11131113

1114+
@since(2.3)
1115+
def date_trunc(format, timestamp):
1116+
"""
1117+
Returns timestamp truncated to the unit specified by the format.
1118+
1119+
:param format: 'year', 'yyyy', 'yy', 'month', 'mon', 'mm',
1120+
'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter'
1121+
1122+
>>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t'])
1123+
>>> df.select(date_trunc('year', df.t).alias('year')).collect()
1124+
[Row(year=datetime.datetime(1997, 1, 1, 0, 0))]
1125+
>>> df.select(date_trunc('mon', df.t).alias('month')).collect()
1126+
[Row(month=datetime.datetime(1997, 2, 1, 0, 0))]
1127+
"""
1128+
sc = SparkContext._active_spark_context
1129+
return Column(sc._jvm.functions.date_trunc(format, _to_java_column(timestamp)))
1130+
1131+
11141132
@since(1.5)
11151133
def next_day(date, dayOfWeek):
11161134
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ object FunctionRegistry {
392392
expression[ToUnixTimestamp]("to_unix_timestamp"),
393393
expression[ToUTCTimestamp]("to_utc_timestamp"),
394394
expression[TruncDate]("trunc"),
395+
expression[TruncTimestamp]("date_trunc"),
395396
expression[UnixTimestamp]("unix_timestamp"),
396397
expression[DayOfWeek]("dayofweek"),
397398
expression[WeekOfYear]("weekofyear"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 132 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,87 +1294,181 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child:
12941294
override def dataType: DataType = TimestampType
12951295
}
12961296

1297-
/**
1298-
* Returns date truncated to the unit specified by the format.
1299-
*/
1300-
// scalastyle:off line.size.limit
1301-
@ExpressionDescription(
1302-
usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.",
1303-
examples = """
1304-
Examples:
1305-
> SELECT _FUNC_('2009-02-12', 'MM');
1306-
2009-02-01
1307-
> SELECT _FUNC_('2015-10-27', 'YEAR');
1308-
2015-01-01
1309-
""",
1310-
since = "1.5.0")
1311-
// scalastyle:on line.size.limit
1312-
case class TruncDate(date: Expression, format: Expression)
1313-
extends BinaryExpression with ImplicitCastInputTypes {
1314-
override def left: Expression = date
1315-
override def right: Expression = format
1316-
1317-
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
1318-
override def dataType: DataType = DateType
1297+
trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
1298+
val instant: Expression
1299+
val format: Expression
13191300
override def nullable: Boolean = true
1320-
override def prettyName: String = "trunc"
13211301

13221302
private lazy val truncLevel: Int =
13231303
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
13241304

1325-
override def eval(input: InternalRow): Any = {
1305+
/**
1306+
* @param input internalRow (time)
1307+
* @param maxLevel Maximum level that can be used for truncation (e.g MONTH for Date input)
1308+
* @param truncFunc function: (time, level) => time
1309+
*/
1310+
protected def evalHelper(input: InternalRow, maxLevel: Int)(
1311+
truncFunc: (Any, Int) => Any): Any = {
13261312
val level = if (format.foldable) {
13271313
truncLevel
13281314
} else {
13291315
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
13301316
}
1331-
if (level == -1) {
1332-
// unknown format
1317+
if (level == DateTimeUtils.TRUNC_INVALID || level > maxLevel) {
1318+
// unknown format or too large level
13331319
null
13341320
} else {
1335-
val d = date.eval(input)
1336-
if (d == null) {
1321+
val t = instant.eval(input)
1322+
if (t == null) {
13371323
null
13381324
} else {
1339-
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
1325+
truncFunc(t, level)
13401326
}
13411327
}
13421328
}
13431329

1344-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1330+
protected def codeGenHelper(
1331+
ctx: CodegenContext,
1332+
ev: ExprCode,
1333+
maxLevel: Int,
1334+
orderReversed: Boolean = false)(
1335+
truncFunc: (String, String) => String)
1336+
: ExprCode = {
13451337
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
13461338

13471339
if (format.foldable) {
1348-
if (truncLevel == -1) {
1340+
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
13491341
ev.copy(code = s"""
13501342
boolean ${ev.isNull} = true;
13511343
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
13521344
} else {
1353-
val d = date.genCode(ctx)
1345+
val t = instant.genCode(ctx)
1346+
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
13541347
ev.copy(code = s"""
1355-
${d.code}
1356-
boolean ${ev.isNull} = ${d.isNull};
1348+
${t.code}
1349+
boolean ${ev.isNull} = ${t.isNull};
13571350
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
13581351
if (!${ev.isNull}) {
1359-
${ev.value} = $dtu.truncDate(${d.value}, $truncLevel);
1352+
${ev.value} = $dtu.$truncFuncStr;
13601353
}""")
13611354
}
13621355
} else {
1363-
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
1356+
nullSafeCodeGen(ctx, ev, (left, right) => {
13641357
val form = ctx.freshName("form")
1358+
val (dateVal, fmt) = if (orderReversed) {
1359+
(right, left)
1360+
} else {
1361+
(left, right)
1362+
}
1363+
val truncFuncStr = truncFunc(dateVal, form)
13651364
s"""
13661365
int $form = $dtu.parseTruncLevel($fmt);
1367-
if ($form == -1) {
1366+
if ($form == -1 || $form > $maxLevel) {
13681367
${ev.isNull} = true;
13691368
} else {
1370-
${ev.value} = $dtu.truncDate($dateVal, $form);
1369+
${ev.value} = $dtu.$truncFuncStr
13711370
}
13721371
"""
13731372
})
13741373
}
13751374
}
13761375
}
13771376

1377+
/**
1378+
* Returns date truncated to the unit specified by the format.
1379+
*/
1380+
// scalastyle:off line.size.limit
1381+
@ExpressionDescription(
1382+
usage = """
1383+
_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.
1384+
`fmt` should be one of ["year", "yyyy", "yy", "mon", "month", "mm"]
1385+
""",
1386+
examples = """
1387+
Examples:
1388+
> SELECT _FUNC_('2009-02-12', 'MM');
1389+
2009-02-01
1390+
> SELECT _FUNC_('2015-10-27', 'YEAR');
1391+
2015-01-01
1392+
""",
1393+
since = "1.5.0")
1394+
// scalastyle:on line.size.limit
1395+
case class TruncDate(date: Expression, format: Expression)
1396+
extends TruncInstant {
1397+
override def left: Expression = date
1398+
override def right: Expression = format
1399+
1400+
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
1401+
override def dataType: DataType = DateType
1402+
override def prettyName: String = "trunc"
1403+
override val instant = date
1404+
1405+
override def eval(input: InternalRow): Any = {
1406+
evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (d: Any, level: Int) =>
1407+
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
1408+
}
1409+
}
1410+
1411+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1412+
codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (date: String, fmt: String) =>
1413+
s"truncDate($date, $fmt);"
1414+
}
1415+
}
1416+
}
1417+
1418+
/**
1419+
* Returns timestamp truncated to the unit specified by the format.
1420+
*/
1421+
// scalastyle:off line.size.limit
1422+
@ExpressionDescription(
1423+
usage = """
1424+
_FUNC_(fmt, ts) - Returns timestamp `ts` truncated to the unit specified by the format model `fmt`.
1425+
`fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "DAY", "DD", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER"]
1426+
""",
1427+
examples = """
1428+
Examples:
1429+
> SELECT _FUNC_('YEAR', '2015-03-05T09:32:05.359');
1430+
2015-01-01T00:00:00
1431+
> SELECT _FUNC_('MM', '2015-03-05T09:32:05.359');
1432+
2015-03-01T00:00:00
1433+
> SELECT _FUNC_('DD', '2015-03-05T09:32:05.359');
1434+
2015-03-05T00:00:00
1435+
> SELECT _FUNC_('HOUR', '2015-03-05T09:32:05.359');
1436+
2015-03-05T09:00:00
1437+
""",
1438+
since = "2.3.0")
1439+
// scalastyle:on line.size.limit
1440+
case class TruncTimestamp(
1441+
format: Expression,
1442+
timestamp: Expression,
1443+
timeZoneId: Option[String] = None)
1444+
extends TruncInstant with TimeZoneAwareExpression {
1445+
override def left: Expression = format
1446+
override def right: Expression = timestamp
1447+
1448+
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType)
1449+
override def dataType: TimestampType = TimestampType
1450+
override def prettyName: String = "date_trunc"
1451+
override val instant = timestamp
1452+
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
1453+
copy(timeZoneId = Option(timeZoneId))
1454+
1455+
def this(format: Expression, timestamp: Expression) = this(format, timestamp, None)
1456+
1457+
override def eval(input: InternalRow): Any = {
1458+
evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_SECOND) { (t: Any, level: Int) =>
1459+
DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, timeZone)
1460+
}
1461+
}
1462+
1463+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1464+
val tz = ctx.addReferenceObj("timeZone", timeZone)
1465+
codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_SECOND, true) {
1466+
(date: String, fmt: String) =>
1467+
s"truncTimestamp($date, $fmt, $tz);"
1468+
}
1469+
}
1470+
}
1471+
13781472
/**
13791473
* Returns the number of days from startDate to endDate.
13801474
*/

0 commit comments

Comments
 (0)