Skip to content

Commit 5ea6ee4

Browse files
fix(typing): Add overloads for ibis lit (#2972)
Co-authored-by: FBruzzesi <[email protected]>
1 parent 57a5dde commit 5ea6ee4

File tree

4 files changed

+39
-23
lines changed

4 files changed

+39
-23
lines changed

narwhals/_ibis/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,12 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
337337
if isinstance(descending, bool):
338338
descending = [descending for _ in range(len(by))]
339339

340-
sort_cols = []
340+
sort_cols: list[Any] = []
341341

342342
for i in range(len(by)):
343343
direction_fn = ibis.desc if descending[i] else ibis.asc
344344
col = direction_fn(by[i], nulls_first=not nulls_last)
345-
sort_cols.append(cast("ir.Column", col))
345+
sort_cols.append(col)
346346

347347
return self._with_native(self.native.order_by(*sort_cols))
348348

narwhals/_ibis/expr.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _std(expr: ir.NumericColumn, ddof: int) -> ir.Value:
227227
return expr.std(how="sample")
228228
n_samples = expr.count()
229229
std_pop = expr.std(how="pop")
230-
ddof_lit = cast("ir.IntegerScalar", ibis.literal(ddof))
230+
ddof_lit = lit(ddof)
231231
return std_pop * n_samples.sqrt() / (n_samples - ddof_lit).sqrt()
232232

233233
return self._with_callable(lambda expr: _std(expr, ddof))
@@ -240,7 +240,7 @@ def _var(expr: ir.NumericColumn, ddof: int) -> ir.Value:
240240
return expr.var(how="sample")
241241
n_samples = expr.count()
242242
var_pop = expr.var(how="pop")
243-
ddof_lit = cast("ir.IntegerScalar", ibis.literal(ddof))
243+
ddof_lit = lit(ddof)
244244
return var_pop * n_samples / (n_samples - ddof_lit)
245245

246246
return self._with_callable(lambda expr: _var(expr, ddof))
@@ -290,35 +290,33 @@ def is_unique(self) -> Self:
290290
)
291291

292292
def rank(self, method: RankMethod, *, descending: bool) -> Self:
293-
def _rank(expr: ir.Column) -> ir.Column:
293+
def _rank(expr: ir.Column) -> ir.Value:
294294
order_by = next(self._sort(expr, descending=[descending], nulls_last=[True]))
295295
window = ibis.window(order_by=order_by)
296296

297297
if method == "dense":
298298
rank_ = order_by.dense_rank()
299299
elif method == "ordinal":
300-
rank_ = cast("ir.IntegerColumn", ibis.row_number().over(window))
300+
rank_ = ibis.row_number().over(window)
301301
else:
302302
rank_ = order_by.rank()
303303

304304
# Ibis uses 0-based ranking. Add 1 to match polars 1-based rank.
305-
rank_ = rank_ + cast("ir.IntegerValue", lit(1))
305+
rank_ = rank_ + lit(1)
306306

307307
# For "max" and "average", adjust using the count of rows in the partition.
308308
if method == "max":
309309
# Define a window partitioned by expr (i.e. each distinct value)
310310
partition = ibis.window(group_by=[expr])
311-
cnt = cast("ir.IntegerValue", expr.count().over(partition))
312-
rank_ = rank_ + cnt - cast("ir.IntegerValue", lit(1))
311+
cnt = expr.count().over(partition)
312+
rank_ = rank_ + cnt - lit(1)
313313
elif method == "average":
314314
partition = ibis.window(group_by=[expr])
315-
cnt = cast("ir.IntegerValue", expr.count().over(partition))
316-
avg = cast(
317-
"ir.NumericValue", (cnt - cast("ir.IntegerScalar", lit(1))) / lit(2.0)
318-
)
315+
cnt = expr.count().over(partition)
316+
avg = cast("ir.NumericValue", (cnt - lit(1)) / lit(2.0))
319317
rank_ = rank_ + avg
320318

321-
return cast("ir.Column", ibis.cases((expr.notnull(), rank_)))
319+
return ibis.cases((expr.notnull(), rank_))
322320

323321
return self._with_callable(_rank)
324322

narwhals/_ibis/namespace.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import operator
44
from functools import reduce
55
from itertools import chain
6-
from typing import TYPE_CHECKING, Any, cast
6+
from typing import TYPE_CHECKING, Any
77

88
import ibis
99
import ibis.expr.types as ir
@@ -88,8 +88,7 @@ def func(df: IbisLazyFrame) -> list[ir.Value]:
8888
for col in cols_casted[1:]:
8989
result = result + separator + col
9090
else:
91-
sep = cast("ir.StringValue", lit(separator))
92-
result = sep.join(cols_casted)
91+
result = lit(separator).join(cols_casted)
9392

9493
return [result]
9594

narwhals/_ibis/utils.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from functools import lru_cache
4-
from typing import TYPE_CHECKING, Any, Literal, cast
4+
from typing import TYPE_CHECKING, Any, Literal, cast, overload
55

66
import ibis
77
import ibis.expr.datatypes as ibis_dtypes
@@ -23,8 +23,27 @@
2323
from narwhals.dtypes import DType
2424
from narwhals.typing import IntoDType, PythonLiteral
2525

26-
lit = ibis.literal
27-
"""Alias for `ibis.literal`."""
26+
Incomplete: TypeAlias = Any
27+
"""Marker for upstream issues."""
28+
29+
30+
@overload
31+
def lit(value: bool, dtype: None = ...) -> ir.BooleanScalar: ... # noqa: FBT001
32+
@overload
33+
def lit(value: int, dtype: None = ...) -> ir.IntegerScalar: ...
34+
@overload
35+
def lit(value: float, dtype: None = ...) -> ir.FloatingScalar: ...
36+
@overload
37+
def lit(value: str, dtype: None = ...) -> ir.StringScalar: ...
38+
@overload
39+
def lit(value: PythonLiteral | ir.Value, dtype: None = ...) -> ir.Scalar: ...
40+
@overload
41+
def lit(value: Any, dtype: Any) -> Incomplete: ...
42+
def lit(value: Any, dtype: Any | None = None) -> Incomplete:
43+
"""Alias for `ibis.literal`."""
44+
literal: Incomplete = ibis.literal
45+
return literal(value, dtype)
46+
2847

2948
BucketUnit: TypeAlias = Literal[
3049
"years",
@@ -231,11 +250,11 @@ def timedelta_to_ibis_interval(td: timedelta) -> ibis.expr.types.temporal.Interv
231250
def function(name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
232251
# Workaround SQL vs Ibis differences.
233252
if name == "row_number":
234-
return ibis.row_number() + 1 # pyright: ignore[reportOperatorIssue]
253+
return ibis.row_number() + lit(1)
235254
if name == "least":
236-
return ibis.least(*args) # pyright: ignore[reportOperatorIssue]
255+
return ibis.least(*args)
237256
if name == "greatest":
238-
return ibis.greatest(*args) # pyright: ignore[reportOperatorIssue]
257+
return ibis.greatest(*args)
239258
expr = args[0]
240259
if name == "var_pop":
241260
return cast("ir.NumericColumn", expr).var(how="pop")

0 commit comments

Comments
 (0)