diff --git a/narwhals/_duckdb/expr_dt.py b/narwhals/_duckdb/expr_dt.py index 3a929f1594..3dbd7cc1c9 100644 --- a/narwhals/_duckdb/expr_dt.py +++ b/narwhals/_duckdb/expr_dt.py @@ -2,9 +2,9 @@ from typing import TYPE_CHECKING -from duckdb import FunctionExpression +from duckdb import FunctionExpression, SQLExpression -from narwhals._duckdb.utils import UNITS_DICT, fetch_rel_time_zone, lit +from narwhals._duckdb.utils import UNITS_DICT, fetch_rel_time_zone, lit, when from narwhals._duration import parse_interval_string from narwhals._utils import not_implemented @@ -113,17 +113,30 @@ def total_microseconds(self) -> DuckDBExpr: def truncate(self, every: str) -> DuckDBExpr: multiple, unit = parse_interval_string(every) - if multiple != 1: - # https://github.com/duckdb/duckdb/issues/17554 - msg = f"Only multiple 1 is currently supported for DuckDB.\nGot {multiple!s}." - raise ValueError(msg) if unit == "ns": msg = "Truncating to nanoseconds is not yet supported for DuckDB." raise NotImplementedError(msg) - format = lit(UNITS_DICT[unit]) + + interval_str = f"INTERVAL '{multiple} {UNITS_DICT[unit]}'" + tz_str = "(select value from duckdb_settings() where name = 'TimeZone')" def _truncate(expr: Expression) -> Expression: - return FunctionExpression("date_trunc", format, expr) + is_timestamptz = FunctionExpression("typeof", expr).cast("varchar") == lit( + "TIMESTAMP WITH TIME ZONE" + ) + return when( + is_timestamptz, + FunctionExpression( + "time_bucket", + SQLExpression(interval_str), + expr, + SQLExpression(tz_str), + ).cast("TIMESTAMP WITH TIME ZONE"), + ).otherwise( + FunctionExpression("time_bucket", SQLExpression(interval_str), expr).cast( + "timestamp" + ) + ) return self._compliant_expr._with_callable(_truncate)