|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from functools import partial |
3 | 4 | from typing import TYPE_CHECKING |
4 | 5 | from typing import Any |
| 6 | +from typing import Container |
| 7 | +from typing import Iterable |
| 8 | +from typing import Literal |
5 | 9 | from typing import Protocol |
6 | 10 |
|
| 11 | +from narwhals._compliant.typing import CompliantExprT |
7 | 12 | from narwhals._compliant.typing import CompliantFrameT |
8 | | -from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co |
9 | 13 | from narwhals._compliant.typing import EagerDataFrameT |
10 | 14 | from narwhals._compliant.typing import EagerExprT |
11 | 15 | from narwhals._compliant.typing import EagerSeriesT_co |
12 | 16 | from narwhals.utils import deprecated |
| 17 | +from narwhals.utils import exclude_column_names |
| 18 | +from narwhals.utils import get_column_names |
| 19 | +from narwhals.utils import passthrough_column_names |
13 | 20 |
|
14 | 21 | if TYPE_CHECKING: |
15 | | - from narwhals._compliant.expr import CompliantExpr |
16 | 22 | from narwhals._compliant.selectors import CompliantSelectorNamespace |
17 | 23 | from narwhals.dtypes import DType |
| 24 | + from narwhals.utils import Implementation |
| 25 | + from narwhals.utils import Version |
18 | 26 |
|
19 | 27 | __all__ = ["CompliantNamespace", "EagerNamespace"] |
20 | 28 |
|
21 | 29 |
|
22 | | -class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]): |
23 | | - def col( |
24 | | - self, *column_names: str |
25 | | - ) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... |
26 | | - def lit( |
27 | | - self, value: Any, dtype: DType | None |
28 | | - ) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... |
| 30 | +class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): |
| 31 | + _implementation: Implementation |
| 32 | + _backend_version: tuple[int, ...] |
| 33 | + _version: Version |
| 34 | + |
| 35 | + def all(self) -> CompliantExprT: |
| 36 | + return self._expr.from_column_names( |
| 37 | + get_column_names, function_name="all", context=self |
| 38 | + ) |
| 39 | + |
| 40 | + def col(self, *column_names: str) -> CompliantExprT: |
| 41 | + return self._expr.from_column_names( |
| 42 | + passthrough_column_names(column_names), function_name="col", context=self |
| 43 | + ) |
| 44 | + |
| 45 | + def exclude(self, excluded_names: Container[str]) -> CompliantExprT: |
| 46 | + return self._expr.from_column_names( |
| 47 | + partial(exclude_column_names, names=excluded_names), |
| 48 | + function_name="exclude", |
| 49 | + context=self, |
| 50 | + ) |
| 51 | + |
| 52 | + def nth(self, *column_indices: int) -> CompliantExprT: |
| 53 | + return self._expr.from_column_indices(*column_indices, context=self) |
| 54 | + |
| 55 | + def len(self) -> CompliantExprT: ... |
| 56 | + def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ... |
| 57 | + def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 58 | + def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 59 | + def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 60 | + def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 61 | + def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 62 | + def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... |
| 63 | + def concat( |
| 64 | + self, |
| 65 | + items: Iterable[CompliantFrameT], |
| 66 | + *, |
| 67 | + how: Literal["horizontal", "vertical", "diagonal"], |
| 68 | + ) -> CompliantFrameT: ... |
| 69 | + def when(self, predicate: CompliantExprT) -> Any: ... |
| 70 | + def concat_str( |
| 71 | + self, |
| 72 | + *exprs: CompliantExprT, |
| 73 | + separator: str, |
| 74 | + ignore_nulls: bool, |
| 75 | + ) -> CompliantExprT: ... |
29 | 76 | @property |
30 | 77 | def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ... |
| 78 | + @property |
| 79 | + def _expr(self) -> type[CompliantExprT]: ... |
31 | 80 |
|
32 | 81 |
|
33 | 82 | class EagerNamespace( |
34 | | - CompliantNamespace[EagerDataFrameT, EagerSeriesT_co], |
| 83 | + CompliantNamespace[EagerDataFrameT, EagerExprT], |
35 | 84 | Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT], |
36 | 85 | ): |
37 | | - @property |
38 | | - def _expr(self) -> type[EagerExprT]: ... |
39 | 86 | @property |
40 | 87 | def _series(self) -> type[EagerSeriesT_co]: ... |
41 | | - def all_horizontal(self, *exprs: EagerExprT) -> EagerExprT: ... |
42 | 88 |
|
43 | 89 | @deprecated( |
44 | 90 | "Internally used for `numpy.ndarray` -> `CompliantSeries`\n" |
|
0 commit comments