Skip to content

Commit 69de731

Browse files
authored
refactor: Shrink Polars*Namespace(s) (#2897)
1 parent 7d5a457 commit 69de731

File tree

8 files changed

+232
-208
lines changed

8 files changed

+232
-208
lines changed

narwhals/_compliant/series.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -285,15 +285,15 @@ def hist_from_bin_count(
285285
...
286286

287287
@property
288-
def str(self) -> Any: ...
288+
def str(self) -> StringNamespace[Self]: ...
289289
@property
290-
def dt(self) -> Any: ...
290+
def dt(self) -> DateTimeNamespace[Self]: ...
291291
@property
292-
def cat(self) -> Any: ...
292+
def cat(self) -> CatNamespace[Self]: ...
293293
@property
294-
def list(self) -> Any: ...
294+
def list(self) -> ListNamespace[Self]: ...
295295
@property
296-
def struct(self) -> Any: ...
296+
def struct(self) -> StructNamespace[Self]: ...
297297

298298

299299
class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]):

narwhals/_polars/dataframe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,9 @@ def __getitem__( # noqa: C901, PLR0912
407407
native = native.select(
408408
self.columns[slice(columns.start, columns.stop, columns.step)]
409409
)
410-
elif is_compliant_series(columns):
410+
# NOTE: `mypy` loses track of `PolarsSeries` when `is_compliant_series` is used here
411+
# `pyright` is fine
412+
elif isinstance(columns, PolarsSeries):
411413
native = native[:, columns.native.to_list()]
412414
else:
413415
native = native[:, columns]

narwhals/_polars/expr.py

Lines changed: 47 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,13 @@
44

55
import polars as pl
66

7-
from narwhals._duration import Interval
87
from narwhals._polars.utils import (
8+
PolarsAnyNamespace,
9+
PolarsCatNamespace,
10+
PolarsDateTimeNamespace,
11+
PolarsListNamespace,
12+
PolarsStringNamespace,
13+
PolarsStructNamespace,
914
extract_args_kwargs,
1015
extract_native,
1116
narwhals_to_native_dtype,
@@ -17,10 +22,11 @@
1722

1823
from typing_extensions import Self
1924

25+
from narwhals._compliant.typing import EvalNames
2026
from narwhals._expression_parsing import ExprKind, ExprMetadata
21-
from narwhals._polars.dataframe import Method
27+
from narwhals._polars.dataframe import Method, PolarsDataFrame
2228
from narwhals._polars.namespace import PolarsNamespace
23-
from narwhals._utils import Version
29+
from narwhals._utils import Version, _LimitedContext
2430
from narwhals.typing import IntoDType
2531

2632

@@ -256,9 +262,26 @@ def struct(self) -> PolarsExprStructNamespace:
256262
_evaluate_output_names: Any
257263
_is_multi_output_unnamed: Any
258264
__call__: Any
259-
from_column_names: Any
260-
from_column_indices: Any
261-
_eval_names_indices: Any
265+
266+
# CompliantExpr + builtin descriptor
267+
# TODO @dangotbanned: Remove in #2713
268+
@classmethod
269+
def from_column_names(
270+
cls,
271+
evaluate_column_names: EvalNames[PolarsDataFrame],
272+
/,
273+
*,
274+
context: _LimitedContext,
275+
) -> Self:
276+
raise NotImplementedError
277+
278+
@classmethod
279+
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
280+
raise NotImplementedError
281+
282+
@staticmethod
283+
def _eval_names_indices(indices: Sequence[int], /) -> EvalNames[PolarsDataFrame]:
284+
raise NotImplementedError
262285

263286
# Polars
264287
abs: Method[Self]
@@ -311,7 +334,7 @@ def struct(self) -> PolarsExprStructNamespace:
311334
var: Method[Self]
312335

313336

314-
class PolarsExprNamespace:
337+
class PolarsExprNamespace(PolarsAnyNamespace[PolarsExpr, pl.Expr]):
315338
def __init__(self, expr: PolarsExpr) -> None:
316339
self._expr = expr
317340

@@ -324,49 +347,14 @@ def native(self) -> pl.Expr:
324347
return self._expr.native
325348

326349

327-
class PolarsExprDateTimeNamespace(PolarsExprNamespace):
328-
def truncate(self, every: str) -> PolarsExpr:
329-
Interval.parse(every) # Ensure consistent error message is raised.
330-
return self.compliant._with_native(self.native.dt.truncate(every))
331-
332-
def offset_by(self, by: str) -> PolarsExpr:
333-
# Ensure consistent error message is raised.
334-
Interval.parse_no_constraints(by)
335-
return self.compliant._with_native(self.native.dt.offset_by(by))
336-
337-
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
338-
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
339-
pos, kwds = extract_args_kwargs(args, kwargs)
340-
return self.compliant._with_native(
341-
getattr(self.native.dt, attr)(*pos, **kwds)
342-
)
350+
class PolarsExprDateTimeNamespace(
351+
PolarsExprNamespace, PolarsDateTimeNamespace[PolarsExpr, pl.Expr]
352+
): ...
343353

344-
return func
345354

346-
to_string: Method[PolarsExpr]
347-
replace_time_zone: Method[PolarsExpr]
348-
convert_time_zone: Method[PolarsExpr]
349-
timestamp: Method[PolarsExpr]
350-
date: Method[PolarsExpr]
351-
year: Method[PolarsExpr]
352-
month: Method[PolarsExpr]
353-
day: Method[PolarsExpr]
354-
hour: Method[PolarsExpr]
355-
minute: Method[PolarsExpr]
356-
second: Method[PolarsExpr]
357-
millisecond: Method[PolarsExpr]
358-
microsecond: Method[PolarsExpr]
359-
nanosecond: Method[PolarsExpr]
360-
ordinal_day: Method[PolarsExpr]
361-
weekday: Method[PolarsExpr]
362-
total_minutes: Method[PolarsExpr]
363-
total_seconds: Method[PolarsExpr]
364-
total_milliseconds: Method[PolarsExpr]
365-
total_microseconds: Method[PolarsExpr]
366-
total_nanoseconds: Method[PolarsExpr]
367-
368-
369-
class PolarsExprStringNamespace(PolarsExprNamespace):
355+
class PolarsExprStringNamespace(
356+
PolarsExprNamespace, PolarsStringNamespace[PolarsExpr, pl.Expr]
357+
):
370358
def zfill(self, width: int) -> PolarsExpr:
371359
backend_version = self.compliant._backend_version
372360
native_result = self.native.str.zfill(width)
@@ -395,53 +383,14 @@ def zfill(self, width: int) -> PolarsExpr:
395383

396384
return self.compliant._with_native(native_result)
397385

398-
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
399-
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
400-
pos, kwds = extract_args_kwargs(args, kwargs)
401-
return self.compliant._with_native(
402-
getattr(self.native.str, attr)(*pos, **kwds)
403-
)
404-
405-
return func
406-
407-
len_chars: Method[PolarsExpr]
408-
replace: Method[PolarsExpr]
409-
replace_all: Method[PolarsExpr]
410-
strip_chars: Method[PolarsExpr]
411-
starts_with: Method[PolarsExpr]
412-
ends_with: Method[PolarsExpr]
413-
contains: Method[PolarsExpr]
414-
slice: Method[PolarsExpr]
415-
split: Method[PolarsExpr]
416-
to_date: Method[PolarsExpr]
417-
to_datetime: Method[PolarsExpr]
418-
to_lowercase: Method[PolarsExpr]
419-
to_uppercase: Method[PolarsExpr]
420-
421386

422-
class PolarsExprCatNamespace(PolarsExprNamespace):
423-
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
424-
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
425-
pos, kwds = extract_args_kwargs(args, kwargs)
426-
return self.compliant._with_native(
427-
getattr(self.native.cat, attr)(*pos, **kwds)
428-
)
429-
430-
return func
431-
432-
get_categories: Method[PolarsExpr]
387+
class PolarsExprCatNamespace(
388+
PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr]
389+
): ...
433390

434391

435392
class PolarsExprNameNamespace(PolarsExprNamespace):
436-
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
437-
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
438-
pos, kwds = extract_args_kwargs(args, kwargs)
439-
return self.compliant._with_native(
440-
getattr(self.native.name, attr)(*pos, **kwds)
441-
)
442-
443-
return func
444-
393+
_accessor = "name"
445394
keep: Method[PolarsExpr]
446395
map: Method[PolarsExpr]
447396
prefix: Method[PolarsExpr]
@@ -450,9 +399,11 @@ def func(*args: Any, **kwargs: Any) -> PolarsExpr:
450399
to_uppercase: Method[PolarsExpr]
451400

452401

453-
class PolarsExprListNamespace(PolarsExprNamespace):
402+
class PolarsExprListNamespace(
403+
PolarsExprNamespace, PolarsListNamespace[PolarsExpr, pl.Expr]
404+
):
454405
def len(self) -> PolarsExpr:
455-
native_expr = self.compliant._native_expr
406+
native_expr = self.native
456407
native_result = native_expr.list.len()
457408

458409
if self.compliant._backend_version < (1, 16): # pragma: no cover
@@ -464,25 +415,7 @@ def len(self) -> PolarsExpr:
464415

465416
return self.compliant._with_native(native_result)
466417

467-
# TODO(FBruzzesi): Remove `pragma: no cover` once other namespace methods are added
468-
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: # pragma: no cover
469-
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
470-
pos, kwds = extract_args_kwargs(args, kwargs)
471-
return self.compliant._with_native(
472-
getattr(self.native.list, attr)(*pos, **kwds)
473-
)
474-
475-
return func
476-
477-
478-
class PolarsExprStructNamespace(PolarsExprNamespace):
479-
def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: # pragma: no cover
480-
def func(*args: Any, **kwargs: Any) -> PolarsExpr:
481-
pos, kwds = extract_args_kwargs(args, kwargs)
482-
return self.compliant._with_native(
483-
getattr(self.native.struct, attr)(*pos, **kwds)
484-
)
485-
486-
return func
487418

488-
field: Method[PolarsExpr]
419+
class PolarsExprStructNamespace(
420+
PolarsExprNamespace, PolarsStructNamespace[PolarsExpr, pl.Expr]
421+
): ...

narwhals/_polars/namespace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ class PolarsNamespace:
3333
min_horizontal: Method[PolarsExpr]
3434
max_horizontal: Method[PolarsExpr]
3535

36-
# NOTE: `pyright` accepts, `mypy` doesn't highlight the issue
37-
# error: Type argument "PolarsExpr" of "CompliantWhen" must be a subtype of "CompliantExpr[Any, Any]"
38-
when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]] # type: ignore[type-var]
36+
when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]]
3937

4038
_implementation = Implementation.POLARS
4139

0 commit comments

Comments
 (0)