Skip to content

Commit 4040065

Browse files
chore: use less sql for duckdb, do extra validation (#2331)
* chore: use less sql for duckdb
* join_asof too
* coverage
* rename assert_type -> ensure_type
* docs: Link to why ignore
* simplify predicate
* walrus
* refactor: Be a little fancier in `ensure_type`

---------

Co-authored-by: dangotbanned <[email protected]>
1 parent 70218b0 commit 4040065

File tree

4 files changed

+52
-33
lines changed

4 files changed

+52
-33
lines changed

narwhals/_duckdb/dataframe.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -265,18 +265,21 @@ def join(
265265
if self._backend_version < (1, 1, 4):
266266
msg = f"DuckDB>=1.1.4 is required for cross-join, found version: {self._backend_version}"
267267
raise NotImplementedError(msg)
268-
rel = self.native.set_alias("lhs").cross(
269-
other.native.set_alias("rhs")
270-
) # pragma: no cover
268+
rel = self.native.set_alias("lhs").cross(other.native.set_alias("rhs"))
271269
else:
272270
# help mypy
273271
assert left_on is not None # noqa: S101
274272
assert right_on is not None # noqa: S101
275-
condition = " and ".join(
276-
f'lhs."{left}" = rhs."{right}"' for left, right in zip(left_on, right_on)
273+
it = (
274+
col(f'lhs."{left}"') == col(f'rhs."{right}"')
275+
for left, right in zip(left_on, right_on)
277276
)
277+
condition: duckdb.Expression = reduce(and_, it)
278278
rel = self.native.set_alias("lhs").join(
279-
other.native.set_alias("rhs"), condition=condition, how=native_how
279+
other.native.set_alias("rhs"),
280+
# NOTE: Fixed in `--pre` https://github.com/duckdb/duckdb/pull/16933
281+
condition=condition, # type: ignore[arg-type, unused-ignore]
282+
how=native_how,
280283
)
281284

282285
if native_how in {"inner", "left", "cross", "outer"}:
@@ -310,21 +313,22 @@ def join_asof(
310313
) -> Self:
311314
lhs = self.native
312315
rhs = other.native
313-
conditions = []
316+
conditions: list[duckdb.Expression] = []
314317
if by_left is not None and by_right is not None:
315-
conditions += [
316-
f'lhs."{left}" = rhs."{right}"' for left, right in zip(by_left, by_right)
317-
]
318+
conditions.extend(
319+
col(f'lhs."{left}"') == col(f'rhs."{right}"')
320+
for left, right in zip(by_left, by_right)
321+
)
318322
else:
319323
by_left = by_right = []
320324
if strategy == "backward":
321-
conditions += [f'lhs."{left_on}" >= rhs."{right_on}"']
325+
conditions.append(col(f'lhs."{left_on}"') >= col(f'rhs."{right_on}"'))
322326
elif strategy == "forward":
323-
conditions += [f'lhs."{left_on}" <= rhs."{right_on}"']
327+
conditions.append(col(f'lhs."{left_on}"') <= col(f'rhs."{right_on}"'))
324328
else:
325329
msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB"
326330
raise NotImplementedError(msg)
327-
condition = " and ".join(conditions)
331+
condition: duckdb.Expression = reduce(and_, conditions)
328332
select = ["lhs.*"]
329333
for name in rhs.columns:
330334
if name in lhs.columns and (
@@ -333,6 +337,8 @@ def join_asof(
333337
select.append(f'rhs."{name}" as "{name}{suffix}"')
334338
elif right_on is None or name not in {right_on, *by_right}:
335339
select.append(f'"{name}"')
340+
# Replace with Python API call once
341+
# https://github.com/duckdb/duckdb/discussions/16947 is addressed.
336342
query = f"""
337343
SELECT {",".join(select)}
338344
FROM lhs
@@ -350,8 +356,7 @@ def collect_schema(self: Self) -> dict[str, DType]:
350356
def unique(
351357
self: Self, subset: Sequence[str] | None, *, keep: Literal["any", "none"]
352358
) -> Self:
353-
subset_ = subset if keep == "any" else (subset or self.columns)
354-
if subset_:
359+
if subset_ := subset if keep == "any" else (subset or self.columns):
355360
# Sanitise input
356361
if any(x not in self.columns for x in subset_):
357362
msg = f"Columns {set(subset_).difference(self.columns)} not found in {self.columns}."
@@ -366,13 +371,10 @@ def unique(
366371
count(*) over ({partition_by_sql}) as "{count_name}"
367372
from rel
368373
""" # noqa: S608
369-
if keep == "none":
370-
keep_condition = col(count_name) == lit(1)
371-
else:
372-
keep_condition = col(idx_name) == lit(1)
374+
name = count_name if keep == "none" else idx_name
373375
return self._with_native(
374376
duckdb.sql(query)
375-
.filter(keep_condition)
377+
.filter(col(name) == lit(1))
376378
.select(StarExpression(exclude=[count_name, idx_name]))
377379
)
378380
return self._with_native(self.native.unique(", ".join(self.columns)))
@@ -463,6 +465,8 @@ def unpivot(
463465

464466
unpivot_on = ", ".join(f'"{name}"' for name in on_)
465467
rel = self.native # noqa: F841
468+
# Replace with Python API once
469+
# https://github.com/duckdb/duckdb/discussions/16980 is addressed.
466470
query = f"""
467471
unpivot rel
468472
on {unpivot_on}

narwhals/_duckdb/expr.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
2121
from narwhals._duckdb.utils import WindowInputs
2222
from narwhals._duckdb.utils import col
23+
from narwhals._duckdb.utils import ensure_type
2324
from narwhals._duckdb.utils import generate_order_by_sql
2425
from narwhals._duckdb.utils import generate_partition_by_sql
2526
from narwhals._duckdb.utils import lit
@@ -109,6 +110,8 @@ def _rolling_window_func(
109110
min_samples: int,
110111
ddof: int | None = None,
111112
) -> WindowFunction:
113+
ensure_type(window_size, int, type(None))
114+
ensure_type(min_samples, int)
112115
supported_funcs = ["sum", "mean", "std", "var"]
113116
if center:
114117
half = (window_size - 1) // 2
@@ -139,11 +142,10 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:
139142
else: # pragma: no cover
140143
msg = f"Only the following functions are supported: {supported_funcs}.\nGot: {func_name}."
141144
raise ValueError(msg)
142-
sql = (
143-
f"case when count({window_inputs.expr}) over {window} >= {min_samples}"
144-
f"then {func_}({window_inputs.expr}) over {window} end"
145-
)
146-
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
145+
condition_sql = f"count({window_inputs.expr}) over {window} >= {min_samples}"
146+
condition = SQLExpression(condition_sql)
147+
value = SQLExpression(f"{func_}({window_inputs.expr}) over {window}")
148+
return when(condition, value)
147149

148150
return func
149151

@@ -540,6 +542,8 @@ def round(self: Self, decimals: int) -> Self:
540542
)
541543

542544
def shift(self, n: int) -> Self:
545+
ensure_type(n, int)
546+
543547
def func(window_inputs: WindowInputs) -> duckdb.Expression:
544548
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
545549
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
@@ -560,8 +564,8 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:
560564
)
561565
else:
562566
partition_by_sql = f"partition by {window_inputs.expr}"
563-
sql = f"row_number() over({partition_by_sql} {order_by_sql}) == 1"
564-
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
567+
sql = f"row_number() over({partition_by_sql} {order_by_sql})"
568+
return SQLExpression(sql) == lit(1) # type: ignore[no-any-return, unused-ignore]
565569

566570
return self._with_window_function(func)
567571

@@ -575,8 +579,8 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:
575579
)
576580
else:
577581
partition_by_sql = f"partition by {window_inputs.expr}"
578-
sql = f"row_number() over({partition_by_sql} {order_by_sql}) == 1"
579-
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
582+
sql = f"row_number() over({partition_by_sql} {order_by_sql})"
583+
return SQLExpression(sql) == lit(1) # type: ignore[no-any-return, unused-ignore]
580584

581585
return self._with_window_function(func)
582586

narwhals/_duckdb/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,12 @@ def generate_order_by_sql(*order_by: str, ascending: bool) -> str:
235235
else:
236236
by_sql = ", ".join([f'"{x}" desc nulls last' for x in order_by])
237237
return f"order by {by_sql}"
238+
239+
240+
def ensure_type(obj: Any, *valid_types: type[Any]) -> None:
241+
# Use this for extra (possibly redundant) validation in places where we
242+
# use SQLExpression, as an extra guard against unsafe inputs.
243+
if not isinstance(obj, valid_types): # pragma: no cover
244+
tp_names = " | ".join(tp.__name__ for tp in valid_types)
245+
msg = f"Expected {tp_names!r}, got: {type(obj).__name__!r}"
246+
raise TypeError(msg)

narwhals/_spark_like/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
# NOTE: don't lru_cache this as `ModuleType` isn't hashable
4444
def native_to_narwhals_dtype(
4545
dtype: _NativeDType, version: Version, spark_types: ModuleType
46-
) -> DType: # pragma: no cover
46+
) -> DType:
4747
dtypes = import_dtypes_module(version=version)
4848
if TYPE_CHECKING:
4949
native = sqlframe_types
@@ -69,13 +69,15 @@ def native_to_narwhals_dtype(
6969
if isinstance(dtype, native.DateType):
7070
return dtypes.Date()
7171
if isinstance(dtype, native.TimestampNTZType):
72-
return dtypes.Datetime()
72+
# TODO(marco): cover this
73+
return dtypes.Datetime() # pragma: no cover
7374
if isinstance(dtype, native.TimestampType):
7475
# TODO(marco): is UTC correct, or should we be getting the connection timezone?
7576
# https://github.com/narwhals-dev/narwhals/issues/2165
7677
return dtypes.Datetime(time_zone="UTC")
7778
if isinstance(dtype, native.DecimalType):
78-
return dtypes.Decimal()
79+
# TODO(marco): cover this
80+
return dtypes.Decimal() # pragma: no cover
7981
if isinstance(dtype, native.ArrayType):
8082
return dtypes.List(
8183
inner=native_to_narwhals_dtype(
@@ -96,7 +98,7 @@ def native_to_narwhals_dtype(
9698
)
9799
if isinstance(dtype, native.BinaryType):
98100
return dtypes.Binary()
99-
return dtypes.Unknown()
101+
return dtypes.Unknown() # pragma: no cover
100102

101103

102104
def narwhals_to_native_dtype(

0 commit comments

Comments
 (0)