Skip to content

Commit 23ec32b

Browse files
committed
🚚 move HasArrayNamespace
Signed-off-by: nstarman <[email protected]>
1 parent 85ff8ac commit 23ec32b

File tree

4 files changed

+72
-16
lines changed

4 files changed

+72
-16
lines changed

‎src/array_api_typing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
"__version_tuple__",
77
)
88

9-
from ._namespace import HasArrayNamespace
9+
from ._array import HasArrayNamespace
1010
from ._version import version as __version__, version_tuple as __version_tuple__

‎src/array_api_typing/_namespace.py renamed to ‎src/array_api_typing/_array.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44
from typing import Literal, Protocol
55
from typing_extensions import TypeVar
66

7-
T_co = TypeVar("T_co", covariant=True, default=ModuleType)
7+
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
88

99

10-
class HasArrayNamespace(Protocol[T_co]):
10+
class HasArrayNamespace(Protocol[NamespaceT_co]):
1111
"""Protocol for classes that have an `__array_namespace__` method.
1212
13+
This `Protocol` is intended for use in static typing to ensure that an
14+
object has an `__array_namespace__` method that returns a namespace for
15+
array operations. This `Protocol` should not be used at runtime, for type
16+
checking or as a base class.
17+
1318
Example:
1419
>>> import array_api_typing as xpt
1520
>>>
@@ -27,4 +32,4 @@ class HasArrayNamespace(Protocol[T_co]):
2732

2833
def __array_namespace__(
2934
self, /, *, api_version: Literal["2021.12"] | None = None
30-
) -> T_co: ...
35+
) -> NamespaceT_co: ...

‎tests/integration/test_numpy1.pyi

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,36 @@
1-
from typing import Any
1+
# mypy: disable-error-code="no-redef"
22

3-
# requires numpy < 2
4-
import numpy.array_api as np
3+
from types import ModuleType
4+
from typing import TypeAlias
5+
6+
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
57

68
import array_api_typing as xpt
79

8-
###
9-
# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`.
10+
# DType aliases
11+
F32: TypeAlias = np.float32
12+
I32: TypeAlias = np.int32
13+
14+
# Define NDArrays against which we can test the protocols
15+
nparr = np.eye(2)
16+
nparr_i32 = np.array([1], dtype=I32)
17+
nparr_f32 = np.array([1.0], dtype=F32)
18+
nparr_b = np.array([True], dtype=np.bool_)
19+
20+
# =========================================================
21+
# `xpt.HasArrayNamespace`
22+
23+
_: xpt.HasArrayNamespace[ModuleType] = nparr
24+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
25+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
26+
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
27+
28+
# Check `__array_namespace__` method
29+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
30+
ns: ModuleType = a_ns.__array_namespace__()
1031

11-
arr = np.eye(2)
12-
arr_namespace: xpt.HasArrayNamespace[Any] = arr
32+
# Incorrect values are caught when using `__array_namespace__` and
33+
# backpropagated to the type of `a_ns`
34+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
35+
a_badns: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]
36+
a_badns.__array_namespace__() # triggers error above

‎tests/integration/test_numpy2.pyi

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,38 @@
1-
from typing import Any
1+
# mypy: disable-error-code="no-redef"
22

3+
from types import ModuleType
4+
from typing import Any, TypeAlias
5+
6+
import numpy as np
37
import numpy.typing as npt
48

59
import array_api_typing as xpt
610

7-
###
8-
# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`.
11+
# DType aliases
12+
F32: TypeAlias = np.float32
13+
I32: TypeAlias = np.int32
14+
15+
# Define NDArrays against which we can test the protocols
16+
nparr: npt.NDArray[Any]
17+
nparr_i32: npt.NDArray[I32] = np.array([1], dtype=I32)
18+
nparr_f32: npt.NDArray[F32] = np.array([1.0], dtype=F32)
19+
nparr_b: npt.NDArray[np.bool_] = np.array([True], dtype=np.bool_)
20+
21+
# =========================================================
22+
# `xpt.HasArrayNamespace`
23+
24+
# Check assignment
25+
_: xpt.HasArrayNamespace[ModuleType] = nparr
26+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
27+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
28+
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
29+
30+
# Check `__array_namespace__` method
31+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
32+
ns: ModuleType = a_ns.__array_namespace__()
933

10-
arr: npt.NDArray[Any]
11-
arr_namespace: xpt.HasArrayNamespace[Any] = arr
34+
# Incorrect values are caught when using `__array_namespace__` and
35+
# backpropagated to the type of `a_ns`
36+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
37+
a_badns: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]
38+
a_badns.__array_namespace__() # triggers error above

0 commit comments

Comments
 (0)