Skip to content

Commit c8cbd78

Browse files
committed
Merge remote-tracking branch 'upstream/main' into simp-pandas-group-by
2 parents 3065418 + a9f54b3 commit c8cbd78

39 files changed

+293
-291
lines changed

narwhals/_duckdb/expr.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
1313
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
1414
from narwhals._duckdb.utils import (
15+
DeferredTimeZone,
1516
F,
1617
col,
1718
lit,
@@ -678,11 +679,26 @@ def _fill_constant(expr: Expression, value: Any) -> Expression:
678679
return self._with_elementwise(_fill_constant, value=value)
679680

680681
def cast(self, dtype: IntoDType) -> Self:
681-
def func(expr: Expression) -> Expression:
682-
native_dtype = narwhals_to_native_dtype(dtype, self._version)
683-
return expr.cast(DuckDBPyType(native_dtype))
682+
def func(df: DuckDBLazyFrame) -> list[Expression]:
683+
tz = DeferredTimeZone(df.native)
684+
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
685+
return [expr.cast(DuckDBPyType(native_dtype)) for expr in self(df)]
686+
687+
def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
688+
tz = DeferredTimeZone(df.native)
689+
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
690+
return [
691+
expr.cast(DuckDBPyType(native_dtype))
692+
for expr in self.window_function(df, inputs)
693+
]
684694

685-
return self._with_elementwise(func)
695+
return self.__class__(
696+
func,
697+
window_f,
698+
evaluate_output_names=self._evaluate_output_names,
699+
alias_output_names=self._alias_output_names,
700+
version=self._version,
701+
)
686702

687703
@requires.backend_version((1, 3))
688704
def is_unique(self) -> Self:

narwhals/_duckdb/expr_dt.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,71 +29,71 @@ class DuckDBExprDateTimeNamespace(
2929
LazyExprNamespace["DuckDBExpr"], DateTimeNamespace["DuckDBExpr"]
3030
):
3131
def year(self) -> DuckDBExpr:
32-
return self.compliant._with_callable(lambda expr: F("year", expr))
32+
return self.compliant._with_elementwise(lambda expr: F("year", expr))
3333

3434
def month(self) -> DuckDBExpr:
35-
return self.compliant._with_callable(lambda expr: F("month", expr))
35+
return self.compliant._with_elementwise(lambda expr: F("month", expr))
3636

3737
def day(self) -> DuckDBExpr:
38-
return self.compliant._with_callable(lambda expr: F("day", expr))
38+
return self.compliant._with_elementwise(lambda expr: F("day", expr))
3939

4040
def hour(self) -> DuckDBExpr:
41-
return self.compliant._with_callable(lambda expr: F("hour", expr))
41+
return self.compliant._with_elementwise(lambda expr: F("hour", expr))
4242

4343
def minute(self) -> DuckDBExpr:
44-
return self.compliant._with_callable(lambda expr: F("minute", expr))
44+
return self.compliant._with_elementwise(lambda expr: F("minute", expr))
4545

4646
def second(self) -> DuckDBExpr:
47-
return self.compliant._with_callable(lambda expr: F("second", expr))
47+
return self.compliant._with_elementwise(lambda expr: F("second", expr))
4848

4949
def millisecond(self) -> DuckDBExpr:
50-
return self.compliant._with_callable(
50+
return self.compliant._with_elementwise(
5151
lambda expr: F("millisecond", expr) - F("second", expr) * lit(MS_PER_SECOND)
5252
)
5353

5454
def microsecond(self) -> DuckDBExpr:
55-
return self.compliant._with_callable(
55+
return self.compliant._with_elementwise(
5656
lambda expr: F("microsecond", expr) - F("second", expr) * lit(US_PER_SECOND)
5757
)
5858

5959
def nanosecond(self) -> DuckDBExpr:
60-
return self.compliant._with_callable(
60+
return self.compliant._with_elementwise(
6161
lambda expr: F("nanosecond", expr) - F("second", expr) * lit(NS_PER_SECOND)
6262
)
6363

6464
def to_string(self, format: str) -> DuckDBExpr:
65-
return self.compliant._with_callable(
65+
return self.compliant._with_elementwise(
6666
lambda expr: F("strftime", expr, lit(format))
6767
)
6868

6969
def weekday(self) -> DuckDBExpr:
70-
return self.compliant._with_callable(lambda expr: F("isodow", expr))
70+
return self.compliant._with_elementwise(lambda expr: F("isodow", expr))
7171

7272
def ordinal_day(self) -> DuckDBExpr:
73-
return self.compliant._with_callable(lambda expr: F("dayofyear", expr))
73+
return self.compliant._with_elementwise(lambda expr: F("dayofyear", expr))
7474

7575
def date(self) -> DuckDBExpr:
76-
return self.compliant._with_callable(lambda expr: expr.cast("date"))
76+
return self.compliant._with_elementwise(lambda expr: expr.cast("date"))
7777

7878
def total_minutes(self) -> DuckDBExpr:
79-
return self.compliant._with_callable(
79+
return self.compliant._with_elementwise(
8080
lambda expr: F("datepart", lit("minute"), expr)
8181
)
8282

8383
def total_seconds(self) -> DuckDBExpr:
84-
return self.compliant._with_callable(
84+
return self.compliant._with_elementwise(
8585
lambda expr: lit(SECONDS_PER_MINUTE) * F("datepart", lit("minute"), expr)
8686
+ F("datepart", lit("second"), expr)
8787
)
8888

8989
def total_milliseconds(self) -> DuckDBExpr:
90-
return self.compliant._with_callable(
90+
return self.compliant._with_elementwise(
9191
lambda expr: lit(MS_PER_MINUTE) * F("datepart", lit("minute"), expr)
9292
+ F("datepart", lit("millisecond"), expr)
9393
)
9494

9595
def total_microseconds(self) -> DuckDBExpr:
96-
return self.compliant._with_callable(
96+
return self.compliant._with_elementwise(
9797
lambda expr: lit(US_PER_MINUTE) * F("datepart", lit("minute"), expr)
9898
+ F("datepart", lit("microsecond"), expr)
9999
)
@@ -112,7 +112,7 @@ def truncate(self, every: str) -> DuckDBExpr:
112112
def _truncate(expr: Expression) -> Expression:
113113
return F("date_trunc", format, expr)
114114

115-
return self.compliant._with_callable(_truncate)
115+
return self.compliant._with_elementwise(_truncate)
116116

117117
def _no_op_time_zone(self, time_zone: str) -> DuckDBExpr:
118118
def func(df: DuckDBLazyFrame) -> Sequence[Expression]:
@@ -139,7 +139,7 @@ def convert_time_zone(self, time_zone: str) -> DuckDBExpr:
139139

140140
def replace_time_zone(self, time_zone: str | None) -> DuckDBExpr:
141141
if time_zone is None:
142-
return self.compliant._with_callable(lambda _input: _input.cast("timestamp"))
142+
return self.compliant._with_elementwise(lambda expr: expr.cast("timestamp"))
143143
else:
144144
return self._no_op_time_zone(time_zone)
145145

narwhals/_duckdb/expr_list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ class DuckDBExprListNamespace(
1414
LazyExprNamespace["DuckDBExpr"], ListNamespace["DuckDBExpr"]
1515
):
1616
def len(self) -> DuckDBExpr:
17-
return self.compliant._with_callable(lambda expr: F("len", expr))
17+
return self.compliant._with_elementwise(lambda expr: F("len", expr))

narwhals/_duckdb/expr_str.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ class DuckDBExprStringNamespace(
1717
LazyExprNamespace["DuckDBExpr"], StringNamespace["DuckDBExpr"]
1818
):
1919
def starts_with(self, prefix: str) -> DuckDBExpr:
20-
return self.compliant._with_callable(
20+
return self.compliant._with_elementwise(
2121
lambda expr: F("starts_with", expr, lit(prefix))
2222
)
2323

2424
def ends_with(self, suffix: str) -> DuckDBExpr:
25-
return self.compliant._with_callable(
25+
return self.compliant._with_elementwise(
2626
lambda expr: F("ends_with", expr, lit(suffix))
2727
)
2828

@@ -32,7 +32,7 @@ def func(expr: Expression) -> Expression:
3232
return F("contains", expr, lit(pattern))
3333
return F("regexp_matches", expr, lit(pattern))
3434

35-
return self.compliant._with_callable(func)
35+
return self.compliant._with_elementwise(func)
3636

3737
def slice(self, offset: int, length: int | None) -> DuckDBExpr:
3838
def func(expr: Expression) -> Expression:
@@ -46,35 +46,37 @@ def func(expr: Expression) -> Expression:
4646
F("length", expr) if length is None else lit(length) + offset_lit,
4747
)
4848

49-
return self.compliant._with_callable(func)
49+
return self.compliant._with_elementwise(func)
5050

5151
def split(self, by: str) -> DuckDBExpr:
52-
return self.compliant._with_callable(lambda expr: F("str_split", expr, lit(by)))
52+
return self.compliant._with_elementwise(
53+
lambda expr: F("str_split", expr, lit(by))
54+
)
5355

5456
def len_chars(self) -> DuckDBExpr:
55-
return self.compliant._with_callable(lambda expr: F("length", expr))
57+
return self.compliant._with_elementwise(lambda expr: F("length", expr))
5658

5759
def to_lowercase(self) -> DuckDBExpr:
58-
return self.compliant._with_callable(lambda expr: F("lower", expr))
60+
return self.compliant._with_elementwise(lambda expr: F("lower", expr))
5961

6062
def to_uppercase(self) -> DuckDBExpr:
61-
return self.compliant._with_callable(lambda expr: F("upper", expr))
63+
return self.compliant._with_elementwise(lambda expr: F("upper", expr))
6264

6365
def strip_chars(self, characters: str | None) -> DuckDBExpr:
6466
import string
6567

66-
return self.compliant._with_callable(
68+
return self.compliant._with_elementwise(
6769
lambda expr: F(
6870
"trim", expr, lit(string.whitespace if characters is None else characters)
6971
)
7072
)
7173

7274
def replace_all(self, pattern: str, value: str, *, literal: bool) -> DuckDBExpr:
7375
if not literal:
74-
return self.compliant._with_callable(
76+
return self.compliant._with_elementwise(
7577
lambda expr: F("regexp_replace", expr, lit(pattern), lit(value), lit("g"))
7678
)
77-
return self.compliant._with_callable(
79+
return self.compliant._with_elementwise(
7880
lambda expr: F("replace", expr, lit(pattern), lit(value))
7981
)
8082

@@ -83,7 +85,7 @@ def to_datetime(self, format: str | None) -> DuckDBExpr:
8385
msg = "Cannot infer format with DuckDB backend, please specify `format` explicitly."
8486
raise NotImplementedError(msg)
8587

86-
return self.compliant._with_callable(
88+
return self.compliant._with_elementwise(
8789
lambda expr: F("strptime", expr, lit(format))
8890
)
8991

@@ -119,6 +121,8 @@ def func(expr: Expression) -> Expression:
119121
.otherwise(expr)
120122
)
121123

124+
# can't use `_with_elementwise` due to `when` operator.
125+
# TODO(unassigned): implement `window_func` like we do in `Expr.cast`
122126
return self.compliant._with_callable(func)
123127

124128
replace = not_implemented()

narwhals/_duckdb/expr_struct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ class DuckDBExprStructNamespace(
1414
LazyExprNamespace["DuckDBExpr"], StructNamespace["DuckDBExpr"]
1515
):
1616
def field(self, name: str) -> DuckDBExpr:
17-
return self.compliant._with_callable(
17+
return self.compliant._with_elementwise(
1818
lambda expr: F("struct_extract", expr, lit(name))
1919
).alias(name)

narwhals/_duckdb/namespace.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@
1313
from narwhals._duckdb.dataframe import DuckDBLazyFrame
1414
from narwhals._duckdb.expr import DuckDBExpr
1515
from narwhals._duckdb.selectors import DuckDBSelectorNamespace
16-
from narwhals._duckdb.utils import F, concat_str, lit, narwhals_to_native_dtype, when
16+
from narwhals._duckdb.utils import (
17+
DeferredTimeZone,
18+
F,
19+
concat_str,
20+
lit,
21+
narwhals_to_native_dtype,
22+
when,
23+
)
1724
from narwhals._expression_parsing import (
1825
combine_alias_output_names,
1926
combine_evaluate_output_names,
@@ -142,13 +149,11 @@ def when(self, predicate: DuckDBExpr) -> DuckDBWhen:
142149
return DuckDBWhen.from_expr(predicate, context=self)
143150

144151
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr:
145-
def func(_df: DuckDBLazyFrame) -> list[Expression]:
152+
def func(df: DuckDBLazyFrame) -> list[Expression]:
153+
tz = DeferredTimeZone(df.native)
146154
if dtype is not None:
147-
return [
148-
lit(value).cast(
149-
narwhals_to_native_dtype(dtype, version=self._version) # type: ignore[arg-type]
150-
)
151-
]
155+
target = narwhals_to_native_dtype(dtype, self._version, tz)
156+
return [lit(value).cast(target)] # type: ignore[arg-type]
152157
return [lit(value)]
153158

154159
return self._expr(

narwhals/_duckdb/utils.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
"us": "microsecond",
3131
"ns": "nanosecond",
3232
}
33+
UNIT_TO_TIMESTAMPS = {
34+
"s": "TIMESTAMP_S",
35+
"ms": "TIMESTAMP_MS",
36+
"us": "TIMESTAMP",
37+
"ns": "TIMESTAMP_NS",
38+
}
3339

3440
col = duckdb.ColumnExpression
3541
"""Alias for `duckdb.ColumnExpression`."""
@@ -182,7 +188,10 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version)
182188
"float": dtypes.Float32(),
183189
"varchar": dtypes.String(),
184190
"date": dtypes.Date(),
191+
"timestamp_s": dtypes.Datetime("s"),
192+
"timestamp_ms": dtypes.Datetime("ms"),
185193
"timestamp": dtypes.Datetime(),
194+
"timestamp_ns": dtypes.Datetime("ns"),
186195
"boolean": dtypes.Boolean(),
187196
"interval": dtypes.Duration(),
188197
"decimal": dtypes.Decimal(),
@@ -191,7 +200,9 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version)
191200
}.get(duckdb_dtype_id, dtypes.Unknown())
192201

193202

194-
def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> str: # noqa: C901, PLR0912, PLR0915
203+
def narwhals_to_native_dtype( # noqa: PLR0912,PLR0915,C901
204+
dtype: IntoDType, version: Version, deferred_time_zone: DeferredTimeZone
205+
) -> str:
195206
dtypes = version.dtypes
196207
if isinstance_or_issubclass(dtype, dtypes.Decimal):
197208
msg = "Casting to Decimal is not supported yet."
@@ -242,32 +253,40 @@ def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> str: # noqa
242253
raise ValueError(msg)
243254

244255
if isinstance_or_issubclass(dtype, dtypes.Datetime):
245-
_time_unit = dtype.time_unit
246-
_time_zone = dtype.time_zone
247-
msg = "todo"
248-
raise NotImplementedError(msg)
249-
if isinstance_or_issubclass(dtype, dtypes.Duration): # pragma: no cover
250-
_time_unit = dtype.time_unit
251-
msg = "todo"
252-
raise NotImplementedError(msg)
253-
if isinstance_or_issubclass(dtype, dtypes.Date): # pragma: no cover
256+
tu = dtype.time_unit
257+
tz = dtype.time_zone
258+
if not tz:
259+
return UNIT_TO_TIMESTAMPS[tu]
260+
if tu != "us":
261+
msg = f"Only microsecond precision is supported for timezone-aware `Datetime` in DuckDB, got {tu} precision"
262+
raise ValueError(msg)
263+
if tz != (rel_tz := deferred_time_zone.time_zone): # pragma: no cover
264+
msg = f"Only the connection time zone {rel_tz} is supported, got: {tz}."
265+
raise ValueError(msg)
266+
# TODO(unassigned): cover once https://github.com/narwhals-dev/narwhals/issues/2742 addressed
267+
return "TIMESTAMPTZ" # pragma: no cover
268+
if isinstance_or_issubclass(dtype, dtypes.Duration):
269+
if (tu := dtype.time_unit) != "us": # pragma: no cover
270+
msg = f"Only microsecond-precision Duration is supported, got {tu} precision"
271+
return "INTERVAL"
272+
if isinstance_or_issubclass(dtype, dtypes.Date):
254273
return "DATE"
255274
if isinstance_or_issubclass(dtype, dtypes.List):
256-
inner = narwhals_to_native_dtype(dtype.inner, version)
275+
inner = narwhals_to_native_dtype(dtype.inner, version, deferred_time_zone)
257276
return f"{inner}[]"
258-
if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover
277+
if isinstance_or_issubclass(dtype, dtypes.Struct):
259278
inner = ", ".join(
260-
f'"{field.name}" {narwhals_to_native_dtype(field.dtype, version)}'
279+
f'"{field.name}" {narwhals_to_native_dtype(field.dtype, version, deferred_time_zone)}'
261280
for field in dtype.fields
262281
)
263282
return f"STRUCT({inner})"
264-
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
283+
if isinstance_or_issubclass(dtype, dtypes.Array):
265284
shape = dtype.shape
266285
duckdb_shape_fmt = "".join(f"[{item}]" for item in shape)
267286
inner_dtype: Any = dtype
268287
for _ in shape:
269288
inner_dtype = inner_dtype.inner
270-
duckdb_inner = narwhals_to_native_dtype(inner_dtype, version)
289+
duckdb_inner = narwhals_to_native_dtype(inner_dtype, version, deferred_time_zone)
271290
return f"{duckdb_inner}{duckdb_shape_fmt}"
272291
msg = f"Unknown dtype: {dtype}" # pragma: no cover
273292
raise AssertionError(msg)

0 commit comments

Comments
 (0)