Skip to content

Commit fbbc9e1

Browse files
committed
✨ HasDevice
Signed-off-by: nstarman <[email protected]>
1 parent 3cb0930 commit fbbc9e1

File tree

5 files changed

+29
-0
lines changed

5 files changed

+29
-0
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ ignore = [
126126
"FIX", # flake8-fixme
127127
"ISC001", # Conflicts with formatter
128128
"PYI041", # Use `float` instead of `int | float`
129+
"TD002", # Missing author in TODO
130+
"TD003", # Missing issue link for this TODO
129131
]
130132

131133
[tool.ruff.lint.pydocstyle]

src/array_api_typing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
"Array",
55
"HasArrayNamespace",
66
"HasDType",
7+
"HasDevice",
78
"__version__",
89
"__version_tuple__",
910
)
1011

1112
from ._array import (
1213
Array,
1314
HasArrayNamespace,
15+
HasDevice,
1416
HasDType,
1517
)
1618
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"Array",
33
"HasArrayNamespace",
44
"HasDType",
5+
"HasDevice",
56
)
67

78
from types import ModuleType
@@ -68,6 +69,15 @@ def dtype(self, /) -> DTypeT_co:
6869
...
6970

7071

72+
class HasDevice(Protocol):
73+
"""Protocol for array classes that have a device attribute."""
74+
75+
@property
76+
def device(self) -> object: # TODO: more specific type
77+
"""Hardware device the array data resides on."""
78+
...
79+
80+
7181
class Array(
7282
# ------ Attributes -------
7383
HasDType[DTypeT_co],

tests/integration/test_numpy1p0.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ _: xpt.HasDType[dtype[Any]] = nparr
3939
_: xpt.HasDType[dtype[Any]] = nparr_i32
4040
_: xpt.HasDType[dtype[Any]] = nparr_f32
4141

42+
# =========================================================
43+
# `xpt.HasDevice`
44+
45+
_: xpt.HasDevice = nparr
46+
_: xpt.HasDevice = nparr_i32
47+
_: xpt.HasDevice = nparr_f32
48+
4249
# =========================================================
4350
# `xpt.Array`
4451

tests/integration/test_numpy2p0.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ _: xpt.HasDType[np.dtype[I32]] = nparr_i32
4545
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
4646
_: xpt.HasDType[np.dtype[B]] = nparr_b
4747

48+
# =========================================================
49+
# `xpt.HasDevice`
50+
51+
_: xpt.HasDevice = nparr
52+
_: xpt.HasDevice = nparr_i32
53+
_: xpt.HasDevice = nparr_f32
54+
_: xpt.HasDevice = nparr_b
55+
4856
# =========================================================
4957
# `xpt.Array`
5058

0 commit comments

Comments
 (0)