Skip to content

Commit e1526a8

Browse files
authored
feat(typing): Add polars literal aliases (#2352)
* feat(typing): Add some `polars` aliases Towards #2337 * refactor(typing): Replace literals in impls * feat(typing): Add `LazyUniqueKeepStrategy` Not from `polars`, but describes the subset we support for `LazyFrame.unique` * docs(typing): Add trailing docstrings Copied over unchanged, but seems they could do with standardising * docs: Update `typing.md` - Following the existing pattern of `SizeUnit`, `TimeUnit` - They appear in docs, but **not** in `nw.typing.__all__` * refactor(typing): Replace literals in public api * refactor(typing): Update `v1` * docs: Use more consistent markdown - Use `-` for bullet lists - Wrap literals in `"..."` - Single space after `:` - Titlecase for sentences - Inline code for parameter names (e.g. ``on``) - Wrap lines at ~80 characters
1 parent 72f69b0 commit e1526a8

File tree

29 files changed

+232
-217
lines changed

29 files changed

+232
-217
lines changed

docs/api-reference/typing.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@ Narwhals comes fully statically typed. In addition to `nw.DataFrame`, `nw.Expr`,
1919
- IntoSeriesT
2020
- SizeUnit
2121
- TimeUnit
22+
- AsofJoinStrategy
23+
- ClosedInterval
24+
- ConcatMethod
25+
- FillNullStrategy
26+
- JoinStrategy
27+
- PivotAgg
28+
- RankMethod
29+
- RollingInterpolationMethod
30+
- UniqueKeepStrategy
31+
- LazyUniqueKeepStrategy
2232
show_source: false
2333
show_bases: false
2434

narwhals/_arrow/dataframe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@
5959
from narwhals.schema import Schema
6060
from narwhals.typing import CompliantDataFrame
6161
from narwhals.typing import CompliantLazyFrame
62+
from narwhals.typing import JoinStrategy
6263
from narwhals.typing import SizeUnit
64+
from narwhals.typing import UniqueKeepStrategy
6365
from narwhals.typing import _1DArray
6466
from narwhals.typing import _2DArray
6567
from narwhals.utils import Version
@@ -451,7 +453,7 @@ def join(
451453
self: Self,
452454
other: Self,
453455
*,
454-
how: Literal["inner", "left", "full", "cross", "semi", "anti"],
456+
how: JoinStrategy,
455457
left_on: Sequence[str] | None,
456458
right_on: Sequence[str] | None,
457459
suffix: str,
@@ -758,7 +760,7 @@ def unique(
758760
self: ArrowDataFrame,
759761
subset: Sequence[str] | None,
760762
*,
761-
keep: Literal["any", "first", "last", "none"],
763+
keep: UniqueKeepStrategy,
762764
maintain_order: bool | None = None,
763765
) -> ArrowDataFrame:
764766
# The param `maintain_order` is only here for compatibility with the Polars API

narwhals/_arrow/expr.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from typing import TYPE_CHECKING
44
from typing import Any
5-
from typing import Literal
65
from typing import Sequence
76

87
import pyarrow.compute as pc
@@ -25,6 +24,7 @@
2524
from narwhals._compliant.typing import EvalNames
2625
from narwhals._compliant.typing import EvalSeries
2726
from narwhals._expression_parsing import ExprMetadata
27+
from narwhals.typing import RankMethod
2828
from narwhals.utils import Version
2929
from narwhals.utils import _FullContext
3030

@@ -208,12 +208,7 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
208208
def cum_prod(self: Self, *, reverse: bool) -> Self:
209209
return self._reuse_series("cum_prod", reverse=reverse)
210210

211-
def rank(
212-
self: Self,
213-
method: Literal["average", "min", "max", "dense", "ordinal"],
214-
*,
215-
descending: bool,
216-
) -> Self:
211+
def rank(self, method: RankMethod, *, descending: bool) -> Self:
217212
return self._reuse_series("rank", method=method, descending=descending)
218213

219214
ewm_mean = not_implemented()

narwhals/_arrow/namespace.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from narwhals._arrow.typing import ArrowChunkedArray
3535
from narwhals._arrow.typing import Incomplete
3636
from narwhals.dtypes import DType
37+
from narwhals.typing import ConcatMethod
3738
from narwhals.utils import Version
3839

3940

@@ -211,10 +212,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
211212
)
212213

213214
def concat(
214-
self: Self,
215-
items: Iterable[ArrowDataFrame],
216-
*,
217-
how: Literal["horizontal", "vertical", "diagonal"],
215+
self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod
218216
) -> ArrowDataFrame:
219217
dfs = [item.native for item in items]
220218

narwhals/_arrow/series.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Any
55
from typing import Iterable
66
from typing import Iterator
7-
from typing import Literal
87
from typing import Mapping
98
from typing import Sequence
109
from typing import cast
@@ -56,7 +55,11 @@
5655
from narwhals._arrow.typing import _AsPyType
5756
from narwhals._arrow.typing import _BasicDataType
5857
from narwhals.dtypes import DType
58+
from narwhals.typing import ClosedInterval
59+
from narwhals.typing import FillNullStrategy
5960
from narwhals.typing import Into1DArray
61+
from narwhals.typing import RankMethod
62+
from narwhals.typing import RollingInterpolationMethod
6063
from narwhals.typing import _1DArray
6164
from narwhals.typing import _2DArray
6265
from narwhals.utils import Version
@@ -499,10 +502,7 @@ def all(self: Self, *, _return_py_scalar: bool = True) -> bool:
499502
)
500503

501504
def is_between(
502-
self: Self,
503-
lower_bound: Any,
504-
upper_bound: Any,
505-
closed: Literal["left", "right", "none", "both"],
505+
self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval
506506
) -> Self:
507507
_, lower_bound = extract_native(self, lower_bound)
508508
_, upper_bound = extract_native(self, upper_bound)
@@ -636,17 +636,14 @@ def sample(
636636
return self._with_native(self.native.take(mask))
637637

638638
def fill_null(
639-
self: Self,
640-
value: Any | None,
641-
strategy: Literal["forward", "backward"] | None,
642-
limit: int | None,
639+
self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None
643640
) -> Self:
644641
import numpy as np # ignore-banned-import
645642

646643
def fill_aux(
647644
arr: ArrowArray | ArrowChunkedArray,
648645
limit: int,
649-
direction: Literal["forward", "backward"] | None = None,
646+
direction: FillNullStrategy | None = None,
650647
) -> ArrowArray:
651648
# this algorithm first finds the indices of the valid values to fill all the null value positions
652649
# then it calculates the distance of each new index and the original index
@@ -812,9 +809,9 @@ def to_dummies(self: Self, *, separator: str, drop_first: bool) -> ArrowDataFram
812809
).simple_select(*output_order)
813810

814811
def quantile(
815-
self: Self,
812+
self,
816813
quantile: float,
817-
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
814+
interpolation: RollingInterpolationMethod,
818815
*,
819816
_return_py_scalar: bool = True,
820817
) -> float:
@@ -1028,12 +1025,7 @@ def rolling_std(
10281025
** 0.5
10291026
)
10301027

1031-
def rank(
1032-
self: Self,
1033-
method: Literal["average", "min", "max", "dense", "ordinal"],
1034-
*,
1035-
descending: bool,
1036-
) -> Self:
1028+
def rank(self, method: RankMethod, *, descending: bool) -> Self:
10371029
if method == "average":
10381030
msg = (
10391031
"`rank` with `method='average' is not supported for pyarrow backend. "

narwhals/_compliant/dataframe.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@
4040
from narwhals._translate import IntoArrowTable
4141
from narwhals.dtypes import DType
4242
from narwhals.schema import Schema
43+
from narwhals.typing import AsofJoinStrategy
44+
from narwhals.typing import JoinStrategy
45+
from narwhals.typing import LazyUniqueKeepStrategy
4346
from narwhals.typing import SizeUnit
47+
from narwhals.typing import UniqueKeepStrategy
4448
from narwhals.typing import _2DArray
4549
from narwhals.utils import Implementation
4650
from narwhals.utils import _FullContext
@@ -144,7 +148,7 @@ def join(
144148
self: Self,
145149
other: Self,
146150
*,
147-
how: Literal["inner", "left", "full", "cross", "semi", "anti"],
151+
how: JoinStrategy,
148152
left_on: Sequence[str] | None,
149153
right_on: Sequence[str] | None,
150154
suffix: str,
@@ -157,7 +161,7 @@ def join_asof(
157161
right_on: str | None,
158162
by_left: Sequence[str] | None,
159163
by_right: Sequence[str] | None,
160-
strategy: Literal["backward", "forward", "nearest"],
164+
strategy: AsofJoinStrategy,
161165
suffix: str,
162166
) -> Self: ...
163167
def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrame[Any, Any]: ...
@@ -193,7 +197,7 @@ def unique(
193197
self,
194198
subset: Sequence[str] | None,
195199
*,
196-
keep: Literal["any", "first", "last", "none"],
200+
keep: UniqueKeepStrategy,
197201
maintain_order: bool | None = None,
198202
) -> Self: ...
199203
def unpivot(
@@ -284,7 +288,7 @@ def join_asof(
284288
right_on: str | None,
285289
by_left: Sequence[str] | None,
286290
by_right: Sequence[str] | None,
287-
strategy: Literal["backward", "forward", "nearest"],
291+
strategy: AsofJoinStrategy,
288292
suffix: str,
289293
) -> Self: ...
290294
def rename(self, mapping: Mapping[str, str]) -> Self: ...
@@ -295,10 +299,7 @@ def sort(
295299
@deprecated("`LazyFrame.tail` is deprecated and will be removed in a future version.")
296300
def tail(self, n: int) -> Self: ...
297301
def unique(
298-
self,
299-
subset: Sequence[str] | None,
300-
*,
301-
keep: Literal["any", "none"],
302+
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
302303
) -> Self: ...
303304
def unpivot(
304305
self,

narwhals/_compliant/expr.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@
6464
from narwhals._expression_parsing import ExprKind
6565
from narwhals._expression_parsing import ExprMetadata
6666
from narwhals.dtypes import DType
67+
from narwhals.typing import FillNullStrategy
68+
from narwhals.typing import RankMethod
69+
from narwhals.typing import RollingInterpolationMethod
6770
from narwhals.typing import TimeUnit
6871
from narwhals.utils import Implementation
6972
from narwhals.utils import Version
@@ -131,10 +134,7 @@ def n_unique(self) -> Self: ...
131134
def null_count(self) -> Self: ...
132135
def drop_nulls(self) -> Self: ...
133136
def fill_null(
134-
self,
135-
value: Any | None,
136-
strategy: Literal["forward", "backward"] | None,
137-
limit: int | None,
137+
self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None
138138
) -> Self: ...
139139
def diff(self) -> Self: ...
140140
def unique(self) -> Self: ...
@@ -156,12 +156,7 @@ def cum_max(self, *, reverse: bool) -> Self: ...
156156
def cum_prod(self, *, reverse: bool) -> Self: ...
157157
def is_in(self, other: Any) -> Self: ...
158158
def sort(self, *, descending: bool, nulls_last: bool) -> Self: ...
159-
def rank(
160-
self,
161-
method: Literal["average", "min", "max", "dense", "ordinal"],
162-
*,
163-
descending: bool,
164-
) -> Self: ...
159+
def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
165160
def replace_strict(
166161
self,
167162
old: Sequence[Any] | Mapping[Any, Any],
@@ -181,9 +176,7 @@ def sample(
181176
seed: int | None,
182177
) -> Self: ...
183178
def quantile(
184-
self,
185-
quantile: float,
186-
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
179+
self, quantile: float, interpolation: RollingInterpolationMethod
187180
) -> Self: ...
188181
def map_batches(
189182
self,
@@ -657,10 +650,7 @@ def is_nan(self) -> Self:
657650
return self._reuse_series("is_nan")
658651

659652
def fill_null(
660-
self,
661-
value: Any | None,
662-
strategy: Literal["forward", "backward"] | None,
663-
limit: int | None,
653+
self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None
664654
) -> Self:
665655
return self._reuse_series(
666656
"fill_null", value=value, strategy=strategy, limit=limit
@@ -746,9 +736,7 @@ def is_last_distinct(self) -> Self:
746736
return self._reuse_series("is_last_distinct")
747737

748738
def quantile(
749-
self,
750-
quantile: float,
751-
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
739+
self, quantile: float, interpolation: RollingInterpolationMethod
752740
) -> Self:
753741
return self._reuse_series(
754742
"quantile",

narwhals/_compliant/namespace.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Any
66
from typing import Container
77
from typing import Iterable
8-
from typing import Literal
98
from typing import Mapping
109
from typing import Protocol
1110
from typing import Sequence
@@ -35,6 +34,7 @@
3534
from narwhals._compliant.when_then import EagerWhen
3635
from narwhals.dtypes import DType
3736
from narwhals.schema import Schema
37+
from narwhals.typing import ConcatMethod
3838
from narwhals.typing import Into1DArray
3939
from narwhals.typing import _2DArray
4040
from narwhals.utils import Implementation
@@ -75,10 +75,7 @@ def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
7575
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
7676
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
7777
def concat(
78-
self,
79-
items: Iterable[CompliantFrameT],
80-
*,
81-
how: Literal["horizontal", "vertical", "diagonal"],
78+
self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
8279
) -> CompliantFrameT: ...
8380
def when(
8481
self, predicate: CompliantExprT

narwhals/_compliant/series.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Generic
66
from typing import Iterable
77
from typing import Iterator
8-
from typing import Literal
98
from typing import Mapping
109
from typing import Protocol
1110
from typing import Sequence
@@ -40,7 +39,11 @@
4039
from narwhals._compliant.namespace import CompliantNamespace
4140
from narwhals._compliant.namespace import EagerNamespace
4241
from narwhals.dtypes import DType
42+
from narwhals.typing import ClosedInterval
43+
from narwhals.typing import FillNullStrategy
4344
from narwhals.typing import Into1DArray
45+
from narwhals.typing import RankMethod
46+
from narwhals.typing import RollingInterpolationMethod
4447
from narwhals.typing import _1DArray
4548
from narwhals.utils import Implementation
4649
from narwhals.utils import Version
@@ -152,10 +155,7 @@ def ewm_mean(
152155
ignore_nulls: bool,
153156
) -> Self: ...
154157
def fill_null(
155-
self,
156-
value: Any | None,
157-
strategy: Literal["forward", "backward"] | None,
158-
limit: int | None,
158+
self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None
159159
) -> Self: ...
160160
def filter(self, predicate: Any) -> Self: ...
161161
def gather_every(self, n: int, offset: int) -> Self: ...
@@ -169,10 +169,7 @@ def hist(
169169
) -> CompliantDataFrame[Self, Any, Any]: ...
170170
def head(self, n: int) -> Self: ...
171171
def is_between(
172-
self,
173-
lower_bound: Any,
174-
upper_bound: Any,
175-
closed: Literal["left", "right", "none", "both"],
172+
self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval
176173
) -> Self: ...
177174
def is_finite(self) -> Self: ...
178175
def is_first_distinct(self) -> Self: ...
@@ -192,16 +189,9 @@ def mode(self) -> Self: ...
192189
def n_unique(self) -> int: ...
193190
def null_count(self) -> int: ...
194191
def quantile(
195-
self,
196-
quantile: float,
197-
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
192+
self, quantile: float, interpolation: RollingInterpolationMethod
198193
) -> float: ...
199-
def rank(
200-
self,
201-
method: Literal["average", "min", "max", "dense", "ordinal"],
202-
*,
203-
descending: bool,
204-
) -> Self: ...
194+
def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
205195
def replace_strict(
206196
self,
207197
old: Sequence[Any] | Mapping[Any, Any],

0 commit comments

Comments
 (0)