|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from functools import partial |
3 | 4 | from typing import TYPE_CHECKING |
4 | 5 | from typing import Any |
| 6 | +from typing import Container |
| 7 | +from typing import Iterable |
| 8 | +from typing import Literal |
5 | 9 | from typing import Protocol |
6 | 10 |
|
| 11 | +from narwhals._compliant.typing import CompliantExprT |
7 | 12 | from narwhals._compliant.typing import CompliantFrameT |
8 | | -from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co |
9 | 13 | from narwhals._compliant.typing import EagerDataFrameT |
10 | 14 | from narwhals._compliant.typing import EagerExprT |
11 | 15 | from narwhals._compliant.typing import EagerSeriesT_co |
| 16 | +from narwhals.utils import exclude_column_names |
| 17 | +from narwhals.utils import get_column_names |
| 18 | +from narwhals.utils import passthrough_column_names |
12 | 19 |
|
13 | 20 | if TYPE_CHECKING: |
14 | | - from narwhals._compliant.expr import CompliantExpr |
15 | 21 | from narwhals._compliant.selectors import CompliantSelectorNamespace |
16 | 22 | from narwhals.dtypes import DType |
17 | 23 | from narwhals.utils import Implementation |
|
20 | 26 | __all__ = ["CompliantNamespace", "EagerNamespace"] |
21 | 27 |
|
22 | 28 |
|
23 | | -class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]): |
24 | | - def col( |
25 | | - self, *column_names: str |
26 | | - ) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... |
27 | | - def lit( |
28 | | - self, value: Any, dtype: DType | None |
29 | | - ) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... |
| 29 | +class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): |
| 30 | + _implementation: Implementation |
| 31 | + _backend_version: tuple[int, ...] |
| 32 | + _version: Version |
| 33 | + |
| 34 | + def all(self) -> CompliantExprT: |
| 35 | + return self._expr.from_column_names( |
| 36 | + get_column_names, function_name="all", context=self |
| 37 | + ) |
| 38 | + |
| 39 | + def col(self, *column_names: str) -> CompliantExprT: |
| 40 | + return self._expr.from_column_names( |
| 41 | + passthrough_column_names(column_names), function_name="col", context=self |
| 42 | + ) |
| 43 | + |
| 44 | + def exclude(self, excluded_names: Container[str]) -> CompliantExprT: |
| 45 | + return self._expr.from_column_names( |
| 46 | + partial(exclude_column_names, names=excluded_names), |
| 47 | + function_name="exclude", |
| 48 | + context=self, |
| 49 | + ) |
| 50 | + |
| 51 | + def nth(self, *column_indices: int) -> CompliantExprT: |
| 52 | + return self._expr.from_column_indices(*column_indices, context=self) |
| 53 | + |
| 54 | + def len(self) -> CompliantExprT: ... |
| 55 | + def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ... |
| 56 | + def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 57 | + def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 58 | + def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 59 | + def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 60 | + def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 61 | + def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 62 | + def concat( |
| 63 | + self, |
| 64 | + items: Iterable[CompliantFrameT], |
| 65 | + *, |
| 66 | + how: Literal["horizontal", "vertical", "diagonal"], |
| 67 | + ) -> CompliantFrameT: ... |
| 68 | + def when(self, predicate: CompliantExprT) -> Any: ... |
| 69 | + def concat_str( |
| 70 | + self, |
| 71 | + *exprs: CompliantExprT, |
| 72 | + separator: str, |
| 73 | + ignore_nulls: bool, |
| 74 | + ) -> CompliantExprT: ... |
30 | 75 | @property |
31 | 76 | def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ... |
| 77 | + @property |
| 78 | + def _expr(self) -> type[CompliantExprT]: ... |
32 | 79 |
|
33 | 80 |
|
34 | 81 | class EagerNamespace( |
35 | | - CompliantNamespace[EagerDataFrameT, EagerSeriesT_co], |
| 82 | + CompliantNamespace[EagerDataFrameT, EagerExprT], |
36 | 83 | Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT], |
37 | 84 | ): |
38 | | - _implementation: Implementation |
39 | | - _backend_version: tuple[int, ...] |
40 | | - _version: Version |
41 | | - |
42 | | - @property |
43 | | - def _expr(self) -> type[EagerExprT]: ... |
44 | 85 | @property |
45 | 86 | def _series(self) -> type[EagerSeriesT_co]: ... |
46 | | - def all_horizontal(self, *exprs: EagerExprT) -> EagerExprT: ... |
|
0 commit comments