Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions narwhals/_duckdb/expr_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from typing import TYPE_CHECKING

from duckdb import FunctionExpression
from duckdb import FunctionExpression, SQLExpression

from narwhals._duckdb.utils import UNITS_DICT, fetch_rel_time_zone, lit
from narwhals._duckdb.utils import UNITS_DICT, fetch_rel_time_zone, lit, when
from narwhals._duration import parse_interval_string
from narwhals._utils import not_implemented

Expand Down Expand Up @@ -113,17 +113,30 @@ def total_microseconds(self) -> DuckDBExpr:

def truncate(self, every: str) -> DuckDBExpr:
multiple, unit = parse_interval_string(every)
if multiple != 1:
# https://github.com/duckdb/duckdb/issues/17554
msg = f"Only multiple 1 is currently supported for DuckDB.\nGot {multiple!s}."
raise ValueError(msg)
if unit == "ns":
msg = "Truncating to nanoseconds is not yet supported for DuckDB."
raise NotImplementedError(msg)
format = lit(UNITS_DICT[unit])

interval_str = f"INTERVAL '{multiple} {UNITS_DICT[unit]}'"
tz_str = "(select value from duckdb_settings() where name = 'TimeZone')"

def _truncate(expr: Expression) -> Expression:
return FunctionExpression("date_trunc", format, expr)
is_timestamptz = FunctionExpression("typeof", expr).cast("varchar") == lit(
"TIMESTAMP WITH TIME ZONE"
)
return when(
is_timestamptz,
FunctionExpression(
"time_bucket",
SQLExpression(interval_str),
expr,
SQLExpression(tz_str),
).cast("TIMESTAMP WITH TIME ZONE"),
).otherwise(
FunctionExpression("time_bucket", SQLExpression(interval_str), expr).cast(
"timestamp"
)
)

return self._compliant_expr._with_callable(_truncate)

Expand Down
Loading