|
1 |
| -from collections.abc import Iterable |
2 |
| -from typing import Any |
3 |
| - |
4 |
| -import numpy as np |
5 |
| -import numpy.typing as npt |
| 1 | +from collections.abc import Iterable, Mapping |
| 2 | +from typing import Any, Protocol, type_check_only |
6 | 3 |
|
7 | 4 | __all__ = ["byte_bounds", "normalize_axis_index", "normalize_axis_tuple"]
|
8 | 5 |
|
| 6 | +### |
| 7 | + |
| 8 | +@type_check_only |
| 9 | +class _HasSizeAndArrayInterface(Protocol): |
| 10 | + @property |
| 11 | + def size(self, /) -> int: ... |
| 12 | + @property # `TypedDict` cannot be used because it rejects `dict[str, Any]` |
| 13 | + def __array_interface__(self, /) -> Mapping[str, Any]: ... |
| 14 | + |
| 15 | +### |
| 16 | + |
9 | 17 | # NOTE: In practice `byte_bounds` can (potentially) take any object
|
10 | 18 | # implementing the `__array_interface__` protocol. The caveat is
|
11 | 19 | # that certain keys, marked as optional in the spec, must be present for
|
12 | 20 | # `byte_bounds`. This concerns `"strides"` and `"data"`.
|
13 |
| -def byte_bounds(a: np.generic | npt.NDArray[Any]) -> tuple[int, int]: ... |
14 |
| -def normalize_axis_index(axis: int = ..., ndim: int = ..., msg_prefix: str | None = ...) -> int: ... |
| 21 | +def byte_bounds(a: _HasSizeAndArrayInterface) -> tuple[int, int]: ... |
| 22 | + |
| 23 | +### |
| 24 | +def normalize_axis_index(axis: int, ndim: int, msg_prefix: str | None = None) -> int: ... |
15 | 25 | def normalize_axis_tuple(
|
16 | 26 | axis: int | Iterable[int],
|
17 |
| - ndim: int = ..., |
18 |
| - argname: str | None = ..., |
19 |
| - allow_duplicate: bool | None = ..., |
20 |
| -) -> tuple[int, int]: ... |
| 27 | + ndim: int, |
| 28 | + argname: str | None = None, |
| 29 | + allow_duplicate: bool = False, |
| 30 | +) -> tuple[int, ...]: ... |
0 commit comments