Skip to content

Commit 8e65c33

Browse files
authored
enh: reflect connection time zone in Datetime for DuckDB, support convert_time_zone and replace_time_zone (#2590)
1 parent d3acd98 commit 8e65c33

File tree

8 files changed

+168
-27
lines changed

8 files changed

+168
-27
lines changed

narwhals/_duckdb/dataframe.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from duckdb import FunctionExpression, StarExpression
1010

1111
from narwhals._duckdb.utils import (
12+
DeferredTimeZone,
1213
col,
1314
evaluate_exprs,
1415
generate_partition_by_sql,
@@ -34,6 +35,7 @@
3435
import pandas as pd
3536
import pyarrow as pa
3637
from duckdb import Expression
38+
from duckdb.typing import DuckDBPyType
3739
from typing_extensions import Self, TypeIs
3840

3941
from narwhals._compliant.typing import CompliantDataFrameAny
@@ -70,7 +72,7 @@ def __init__(
7072
self._native_frame: duckdb.DuckDBPyRelation = df
7173
self._version = version
7274
self._backend_version = backend_version
73-
self._cached_schema: dict[str, DType] | None = None
75+
self._cached_native_schema: dict[str, DuckDBPyType] | None = None
7476
self._cached_columns: list[str] | None = None
7577
validate_backend_version(self._implementation, self._backend_version)
7678

@@ -212,23 +214,25 @@ def filter(self, predicate: DuckDBExpr) -> Self:
212214

213215
@property
214216
def schema(self) -> dict[str, DType]:
215-
if self._cached_schema is None:
216-
# Note: prefer `self._cached_schema` over `functools.cached_property`
217+
if self._cached_native_schema is None:
218+
# Note: prefer `self._cached_native_schema` over `functools.cached_property`
217219
# due to Python3.13 failures.
218-
self._cached_schema = {
219-
column_name: native_to_narwhals_dtype(duckdb_dtype, self._version)
220-
for column_name, duckdb_dtype in zip(
221-
self.native.columns, self.native.types
222-
)
223-
}
224-
return self._cached_schema
220+
self._cached_native_schema = dict(zip(self.columns, self.native.types))
221+
222+
deferred_time_zone = DeferredTimeZone(self.native)
223+
return {
224+
column_name: native_to_narwhals_dtype(
225+
duckdb_dtype, self._version, deferred_time_zone
226+
)
227+
for column_name, duckdb_dtype in zip(self.native.columns, self.native.types)
228+
}
225229

226230
@property
227231
def columns(self) -> list[str]:
228232
if self._cached_columns is None:
229233
self._cached_columns = (
230234
list(self.schema)
231-
if self._cached_schema is not None
235+
if self._cached_native_schema is not None
232236
else self.native.columns
233237
)
234238
return self._cached_columns

narwhals/_duckdb/expr_dt.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Sequence
44

55
from duckdb import FunctionExpression
66

7-
from narwhals._duckdb.utils import UNITS_DICT, lit
7+
from narwhals._duckdb.utils import UNITS_DICT, fetch_rel_time_zone, lit
88
from narwhals._duration import parse_interval_string
99
from narwhals.utils import not_implemented
1010

1111
if TYPE_CHECKING:
1212
from duckdb import Expression
1313

14+
from narwhals._duckdb.dataframe import DuckDBLazyFrame
1415
from narwhals._duckdb.expr import DuckDBExpr
1516

1617

@@ -124,13 +125,36 @@ def _truncate(expr: Expression) -> Expression:
124125

125126
return self._compliant_expr._with_callable(_truncate)
126127

128+
def _no_op_time_zone(self, time_zone: str) -> DuckDBExpr:
129+
def func(df: DuckDBLazyFrame) -> Sequence[Expression]:
130+
native_series_list = self._compliant_expr(df)
131+
conn_time_zone = fetch_rel_time_zone(df.native)
132+
if conn_time_zone != time_zone:
133+
msg = (
134+
"DuckDB stores the time zone in the connection, rather than in the "
135+
f"data type, so changing the timezone to anything other than {conn_time_zone} "
136+
" (the current connection time zone) is not supported."
137+
)
138+
raise NotImplementedError(msg)
139+
return native_series_list
140+
141+
return self._compliant_expr.__class__(
142+
func,
143+
evaluate_output_names=self._compliant_expr._evaluate_output_names,
144+
alias_output_names=self._compliant_expr._alias_output_names,
145+
backend_version=self._compliant_expr._backend_version,
146+
version=self._compliant_expr._version,
147+
)
148+
149+
def convert_time_zone(self, time_zone: str) -> DuckDBExpr:
150+
return self._no_op_time_zone(time_zone)
151+
127152
def replace_time_zone(self, time_zone: str | None) -> DuckDBExpr:
128153
if time_zone is None:
129154
return self._compliant_expr._with_callable(
130155
lambda _input: _input.cast("timestamp")
131156
)
132-
else: # pragma: no cover
133-
msg = "`replace_time_zone` with non-null `time_zone` not yet implemented for duckdb"
134-
raise NotImplementedError(msg)
157+
else:
158+
return self._no_op_time_zone(time_zone)
135159

136160
total_nanoseconds = not_implemented()

narwhals/_duckdb/series.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import TYPE_CHECKING
44

5-
from narwhals._duckdb.utils import native_to_narwhals_dtype
5+
from narwhals._duckdb.utils import DeferredTimeZone, native_to_narwhals_dtype
66
from narwhals.dependencies import get_duckdb
77

88
if TYPE_CHECKING:
@@ -28,7 +28,11 @@ def __native_namespace__(self) -> ModuleType:
2828

2929
@property
3030
def dtype(self) -> DType:
31-
return native_to_narwhals_dtype(self._native_series.types[0], self._version)
31+
return native_to_narwhals_dtype(
32+
self._native_series.types[0],
33+
self._version,
34+
DeferredTimeZone(self._native_series),
35+
)
3236

3337
def __getattr__(self, attr: str) -> Never:
3438
msg = ( # pragma: no cover

narwhals/_duckdb/utils.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from narwhals.utils import Version, isinstance_or_issubclass
99

1010
if TYPE_CHECKING:
11-
from duckdb import Expression
11+
from duckdb import DuckDBPyRelation, Expression
1212
from duckdb.typing import DuckDBPyType
1313

1414
from narwhals._duckdb.dataframe import DuckDBLazyFrame
@@ -95,21 +95,56 @@ def evaluate_exprs(
9595
return native_results
9696

9797

98-
def native_to_narwhals_dtype(duckdb_dtype: DuckDBPyType, version: Version) -> DType:
98+
class DeferredTimeZone:
99+
"""Object which gets passed between `native_to_narwhals_dtype` calls.
100+
101+
DuckDB stores the time zone in the connection, rather than in the dtypes, so
102+
this ensures that when calculating the schema of a dataframe with multiple
103+
timezone-aware columns, that the connection's time zone is only fetched once.
104+
105+
Note: we cannot make the time zone a cached `DuckDBLazyFrame` property because
106+
the time zone can be modified after `DuckDBLazyFrame` creation:
107+
108+
```python
109+
df = nw.from_native(rel)
110+
print(df.collect_schema())
111+
rel.query("set timezone = 'Asia/Kolkata'")
112+
print(df.collect_schema()) # should change to reflect new time zone
113+
```
114+
"""
115+
116+
_cached_time_zone: str | None = None
117+
118+
def __init__(self, rel: DuckDBPyRelation) -> None:
119+
self._rel = rel
120+
121+
@property
122+
def time_zone(self) -> str:
123+
"""Fetch relation time zone (if it wasn't calculated already)."""
124+
if self._cached_time_zone is None:
125+
self._cached_time_zone = fetch_rel_time_zone(self._rel)
126+
return self._cached_time_zone
127+
128+
129+
def native_to_narwhals_dtype(
130+
duckdb_dtype: DuckDBPyType, version: Version, deferred_time_zone: DeferredTimeZone
131+
) -> DType:
99132
duckdb_dtype_id = duckdb_dtype.id
100133
dtypes = version.dtypes
101134

102135
# Handle nested data types first
103136
if duckdb_dtype_id == "list":
104-
return dtypes.List(native_to_narwhals_dtype(duckdb_dtype.child, version=version))
137+
return dtypes.List(
138+
native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone)
139+
)
105140

106141
if duckdb_dtype_id == "struct":
107142
children = duckdb_dtype.children
108143
return dtypes.Struct(
109144
[
110145
dtypes.Field(
111146
name=child[0],
112-
dtype=native_to_narwhals_dtype(child[1], version=version),
147+
dtype=native_to_narwhals_dtype(child[1], version, deferred_time_zone),
113148
)
114149
for child in children
115150
]
@@ -123,7 +158,7 @@ def native_to_narwhals_dtype(duckdb_dtype: DuckDBPyType, version: Version) -> DT
123158
child, size = child[1].children
124159
shape.insert(0, size[1])
125160

126-
inner = native_to_narwhals_dtype(child[1], version=version)
161+
inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone)
127162
return dtypes.Array(inner=inner, shape=tuple(shape))
128163

129164
if duckdb_dtype_id == "enum":
@@ -132,9 +167,20 @@ def native_to_narwhals_dtype(duckdb_dtype: DuckDBPyType, version: Version) -> DT
132167
categories = duckdb_dtype.children[0][1]
133168
return dtypes.Enum(categories=categories)
134169

170+
if duckdb_dtype_id == "timestamp with time zone":
171+
return dtypes.Datetime(time_zone=deferred_time_zone.time_zone)
172+
135173
return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version)
136174

137175

176+
def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str:
177+
result = rel.query(
178+
"duckdb_settings()", "select value from duckdb_settings() where name = 'TimeZone'"
179+
).fetchone()
180+
assert result is not None # noqa: S101
181+
return result[0]
182+
183+
138184
@lru_cache(maxsize=16)
139185
def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType:
140186
dtypes = version.dtypes
@@ -154,9 +200,6 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version)
154200
"varchar": dtypes.String(),
155201
"date": dtypes.Date(),
156202
"timestamp": dtypes.Datetime(),
157-
# TODO(marco): is UTC correct, or should we be getting the connection timezone?
158-
# https://github.com/narwhals-dev/narwhals/issues/2165
159-
"timestamp with time zone": dtypes.Datetime(time_zone="UTC"),
160203
"boolean": dtypes.Boolean(),
161204
"interval": dtypes.Duration(),
162205
"decimal": dtypes.Decimal(),

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ def duckdb_lazy_constructor(obj: Data) -> duckdb.DuckDBPyRelation:
136136
import duckdb
137137
import polars as pl
138138

139+
duckdb.sql("""set timezone = 'UTC'""")
140+
139141
_df = pl.LazyFrame(obj)
140142
return duckdb.table("_df")
141143

tests/dtypes_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,28 @@ def test_enum_repr() -> None:
457457
def test_enum_hash() -> None:
458458
assert nw.Enum(["a", "b"]) in {nw.Enum(["a", "b"])}
459459
assert nw.Enum(["a", "b"]) not in {nw.Enum(["a", "b", "c"])}
460+
461+
462+
def test_datetime_w_tz_duckdb() -> None:
463+
pytest.importorskip("duckdb")
464+
pytest.importorskip("zoneinfo")
465+
import duckdb
466+
467+
duckdb.sql("""set timezone = 'Europe/Amsterdam'""")
468+
df = nw.from_native(
469+
duckdb.sql("""select * from values (timestamptz '2020-01-01')df(a)""")
470+
)
471+
result = df.collect_schema()
472+
assert result["a"] == nw.Datetime("us", "Europe/Amsterdam")
473+
duckdb.sql("""set timezone = 'Asia/Kathmandu'""")
474+
result = df.collect_schema()
475+
assert result["a"] == nw.Datetime("us", "Asia/Kathmandu")
476+
477+
df = nw.from_native(
478+
duckdb.sql(
479+
"""select * from values (timestamptz '2020-01-01', [[timestamptz '2020-01-02']])df(a,b)"""
480+
)
481+
)
482+
result = df.collect_schema()
483+
assert result["a"] == nw.Datetime("us", "Asia/Kathmandu")
484+
assert result["b"] == nw.List(nw.List(nw.Datetime("us", "Asia/Kathmandu")))

tests/expr_and_series/dt/convert_time_zone_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,22 @@ def test_convert_time_zone_to_none_series(constructor_eager: ConstructorEager) -
134134
df = nw.from_native(constructor_eager(data))
135135
with pytest.raises(TypeError, match="Target `time_zone` cannot be `None`"):
136136
df["a"].dt.convert_time_zone(None) # type: ignore[arg-type]
137+
138+
139+
def test_convert_time_zone_to_connection_tz_duckdb() -> None:
140+
pytest.importorskip("duckdb")
141+
pytest.importorskip("zoneinfo")
142+
import duckdb
143+
from zoneinfo import ZoneInfo
144+
145+
duckdb.sql("set timezone = 'Asia/Kolkata'")
146+
rel = duckdb.sql("""select * from values (timestamptz '2020-01-01') df(a)""")
147+
result = nw.from_native(rel).with_columns(
148+
nw.col("a").dt.convert_time_zone("Asia/Kolkata")
149+
)
150+
expected = {"a": [datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kolkata"))]}
151+
assert_equal_data(result, expected)
152+
with pytest.raises(NotImplementedError):
153+
result = nw.from_native(rel).with_columns(
154+
nw.col("a").dt.convert_time_zone("Asia/Kathmandu")
155+
)

tests/expr_and_series/dt/replace_time_zone_test.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def test_replace_time_zone(
2828
or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,))
2929
):
3030
pytest.skip()
31-
if any(x in str(constructor) for x in ("cudf", "duckdb", "pyspark", "ibis")):
31+
32+
if any(x in str(constructor) for x in ("cudf", "pyspark", "ibis", "duckdb")):
3233
request.applymarker(pytest.mark.xfail)
3334
data = {
3435
"a": [
@@ -129,3 +130,22 @@ def test_replace_time_zone_none_series(constructor_eager: ConstructorEager) -> N
129130
result_str = result.select(df["a"].dt.to_string("%Y-%m-%dT%H:%M"))
130131
expected = {"a": ["2020-01-01T00:00", "2020-01-02T00:00"]}
131132
assert_equal_data(result_str, expected)
133+
134+
135+
def test_replace_time_zone_to_connection_tz_duckdb() -> None:
136+
pytest.importorskip("duckdb")
137+
pytest.importorskip("zoneinfo")
138+
import duckdb
139+
from zoneinfo import ZoneInfo
140+
141+
duckdb.sql("set timezone = 'Asia/Kolkata'")
142+
rel = duckdb.sql("""select * from values (timestamptz '2020-01-01') df(a)""")
143+
result = nw.from_native(rel).with_columns(
144+
nw.col("a").dt.replace_time_zone("Asia/Kolkata")
145+
)
146+
expected = {"a": [datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kolkata"))]}
147+
assert_equal_data(result, expected)
148+
with pytest.raises(NotImplementedError):
149+
result = nw.from_native(rel).with_columns(
150+
nw.col("a").dt.replace_time_zone("Asia/Kathmandu")
151+
)

0 commit comments

Comments
 (0)