|
4 | 4 | from typing import Any |
5 | 5 | from typing import Iterable |
6 | 6 | from typing import Iterator |
7 | | -from typing import Literal |
8 | 7 | from typing import Mapping |
9 | 8 | from typing import Sequence |
10 | 9 | from typing import cast |
|
56 | 55 | from narwhals._arrow.typing import _AsPyType |
57 | 56 | from narwhals._arrow.typing import _BasicDataType |
58 | 57 | from narwhals.dtypes import DType |
| 58 | + from narwhals.typing import ClosedInterval |
| 59 | + from narwhals.typing import FillNullStrategy |
59 | 60 | from narwhals.typing import Into1DArray |
| 61 | + from narwhals.typing import RankMethod |
| 62 | + from narwhals.typing import RollingInterpolationMethod |
60 | 63 | from narwhals.typing import _1DArray |
61 | 64 | from narwhals.typing import _2DArray |
62 | 65 | from narwhals.utils import Version |
@@ -499,10 +502,7 @@ def all(self: Self, *, _return_py_scalar: bool = True) -> bool: |
499 | 502 | ) |
500 | 503 |
|
501 | 504 | def is_between( |
502 | | - self: Self, |
503 | | - lower_bound: Any, |
504 | | - upper_bound: Any, |
505 | | - closed: Literal["left", "right", "none", "both"], |
| 505 | + self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval |
506 | 506 | ) -> Self: |
507 | 507 | _, lower_bound = extract_native(self, lower_bound) |
508 | 508 | _, upper_bound = extract_native(self, upper_bound) |
@@ -636,17 +636,14 @@ def sample( |
636 | 636 | return self._with_native(self.native.take(mask)) |
637 | 637 |
|
638 | 638 | def fill_null( |
639 | | - self: Self, |
640 | | - value: Any | None, |
641 | | - strategy: Literal["forward", "backward"] | None, |
642 | | - limit: int | None, |
| 639 | + self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None |
643 | 640 | ) -> Self: |
644 | 641 | import numpy as np # ignore-banned-import |
645 | 642 |
|
646 | 643 | def fill_aux( |
647 | 644 | arr: ArrowArray | ArrowChunkedArray, |
648 | 645 | limit: int, |
649 | | - direction: Literal["forward", "backward"] | None = None, |
| 646 | + direction: FillNullStrategy | None = None, |
650 | 647 | ) -> ArrowArray: |
651 | 648 | # this algorithm first finds the indices of the valid values to fill all the null value positions |
652 | 649 | # then it calculates the distance of each new index and the original index |
@@ -812,9 +809,9 @@ def to_dummies(self: Self, *, separator: str, drop_first: bool) -> ArrowDataFram |
812 | 809 | ).simple_select(*output_order) |
813 | 810 |
|
814 | 811 | def quantile( |
815 | | - self: Self, |
| 812 | + self, |
816 | 813 | quantile: float, |
817 | | - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], |
| 814 | + interpolation: RollingInterpolationMethod, |
818 | 815 | *, |
819 | 816 | _return_py_scalar: bool = True, |
820 | 817 | ) -> float: |
@@ -1028,12 +1025,7 @@ def rolling_std( |
1028 | 1025 | ** 0.5 |
1029 | 1026 | ) |
1030 | 1027 |
|
1031 | | - def rank( |
1032 | | - self: Self, |
1033 | | - method: Literal["average", "min", "max", "dense", "ordinal"], |
1034 | | - *, |
1035 | | - descending: bool, |
1036 | | - ) -> Self: |
| 1028 | + def rank(self, method: RankMethod, *, descending: bool) -> Self: |
1037 | 1029 | if method == "average": |
1038 | 1030 | msg = ( |
1039 | 1031 | "`rank` with `method='average' is not supported for pyarrow backend. " |
|
0 commit comments