Skip to content

Commit efdbbc0

Browse files
committed
ENH: Add __array__ to the array_api Array object
This is *NOT* part of the array API spec (so it should not be relied on for portable code). However, without this, np.asarray(np.array_api.Array) produces an object array instead of doing the conversion to a NumPy array as expected. This would work once np.asarray() implements dlpack support, but until then, it seems reasonable to make the conversion work. Note that the reverse, calling np.array_api.asarray(np.array), already works because np.array_api.asarray() is just a wrapper for np.asarray(). Original NumPy Commit: 74a3ee7a8b75bf6dc271c9a1a4b55d2ad9758420
1 parent 3d1674e commit efdbbc0

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

array_api_strict/_array_object.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,17 @@ def __repr__(self: Array, /) -> str:
108108
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
109109
return prefix + mid + suffix
110110

111+
# This function is not required by the spec, but we implement it here for
112+
# convenience so that np.asarray(np.array_api.Array) will work.
113+
def __array__(self, dtype=None):
114+
"""
115+
Warning: this method is NOT part of the array API spec. Implementers
116+
of other libraries need not include it, and users should not assume it
117+
will be present in other implementations.
118+
119+
"""
120+
return np.asarray(self._array, dtype=dtype)
121+
111122
# These are various helper functions to make the array behavior match the
112123
# spec in places where it either deviates from or is more strict than
113124
# NumPy behavior

array_api_strict/tests/test_array_object.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,10 @@ def test_array_properties():
315315
assert a.mT.shape == (1, 3, 2)
316316
assert isinstance(b.mT, Array)
317317
assert b.mT.shape == (3, 2)
318+
319+
def test___array__():
320+
a = ones((2, 3), dtype=int16)
321+
assert np.asarray(a) is a._array
322+
b = np.asarray(a, dtype=np.float64)
323+
assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
324+
assert b.dtype == np.float64

0 commit comments

Comments
 (0)