diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 13e57ef1ee2a..0fc0fa402806 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -94,6 +94,7 @@ def __init__( offset=offset, order=order, buffer_ctor_kwargs={"queue": sycl_queue_normalized}, + array_namespace=dpnp, ) @property @@ -201,6 +202,31 @@ def __and__(self, other): # '__array_ufunc__', # '__array_wrap__', + def __array_namespace__(self, /, *, api_version=None): + """ + Returns array namespace, member functions of which implement data API. + + Parameters + ---------- + api_version : str, optional + Request namespace compliant with given version of array API. If + ``None``, namespace for the most recent supported version is + returned. + Default: ``None``. + + Returns + ------- + out : any + An object representing the array API namespace. It should have + every top-level function defined in the specification as + an attribute. It may contain other public names as well, but it is + recommended to only include those names that are part of the + specification. + + """ + + return self._array_obj.__array_namespace__(api_version=api_version) + def __bool__(self): """``True`` if self else ``False``.""" return self._array_obj.__bool__() @@ -327,15 +353,7 @@ def __getitem__(self, key): key = _get_unwrapped_index_key(key) item = self._array_obj.__getitem__(key) - if not isinstance(item, dpt.usm_ndarray): - raise RuntimeError( - "Expected dpctl.tensor.usm_ndarray, got {}" - "".format(type(item)) - ) - - res = self.__new__(dpnp_array) - res._array_obj = item - return res + return dpnp_array._create_from_usm_ndarray(item) # '__getstate__', @@ -606,6 +624,7 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray): ) res = dpnp_array.__new__(dpnp_array) res._array_obj = usm_ary + res._array_obj._set_namespace(dpnp) return res def all(self, axis=None, out=None, keepdims=False, *, where=True): @@ -1749,17 +1768,16 @@ def transpose(self, *axes): if axes_len == 1 and isinstance(axes[0], (tuple, list)): axes = axes[0] - res = self.__new__(dpnp_array) if ndim == 2 and axes_len == 0: - res._array_obj = self._array_obj.T + usm_res = self._array_obj.T else: if len(axes) == 0 or axes[0] is None: # self.transpose().shape == self.shape[::-1] # self.transpose(None).shape == self.shape[::-1] axes = tuple((ndim - x - 1) for x in range(ndim)) - res._array_obj = dpt.permute_dims(self._array_obj, axes) - return res + usm_res = dpt.permute_dims(self._array_obj, axes) + return dpnp_array._create_from_usm_ndarray(usm_res) def var( self, diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index a5d172155dbe..46e0d610cedf 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -622,14 +622,8 @@ def diagonal(a, offset=0, axis1=0, axis2=1): out_strides = a_straides[:-2] + (1,) out_offset = a_element_offset - return dpnp_array._create_from_usm_ndarray( - dpt.usm_ndarray( - out_shape, - dtype=a.dtype, - buffer=a.get_array(), - strides=out_strides, - offset=out_offset, - ) + return dpnp_array( + out_shape, buffer=a, strides=out_strides, offset=out_offset ) diff --git a/dpnp/tests/test_ndarray.py b/dpnp/tests/test_ndarray.py index 153b3d0c2e9e..a184af6ba225 100644 --- a/dpnp/tests/test_ndarray.py +++ b/dpnp/tests/test_ndarray.py @@ -1,7 +1,11 @@ import dpctl.tensor as dpt import numpy import pytest -from numpy.testing import assert_allclose, assert_array_equal +from numpy.testing import ( + assert_allclose, + assert_array_equal, + assert_raises_regex, +) import dpnp @@ -104,6 +108,48 @@ def test_flags_writable(): assert not a.imag.flags.writable +class TestArrayNamespace: + def test_basic(self): + a = dpnp.arange(2) + xp = a.__array_namespace__() + assert xp is dpnp + + @pytest.mark.parametrize("api_version", [None, "2023.12"]) + def test_api_version(self, api_version): + a = dpnp.arange(2) + xp = a.__array_namespace__(api_version=api_version) + assert xp is dpnp + + @pytest.mark.parametrize("api_version", ["2021.12", "2022.12", "2024.12"]) + def test_unsupported_api_version(self, api_version): + a = dpnp.arange(2) + assert_raises_regex( + ValueError, + "Only 2023.12 is supported", + a.__array_namespace__, + api_version=api_version, + ) + + @pytest.mark.parametrize( + "api_version", + [ + 2023, + (2022,), + [ + 2021, + ], + ], + ) + def test_wrong_api_version(self, api_version): + a = dpnp.arange(2) + assert_raises_regex( + TypeError, + "Expected type str", + a.__array_namespace__, + api_version=api_version, + ) + + class TestItem: @pytest.mark.parametrize("args", [2, 7, (1, 2), (2, 0)]) def test_basic(self, args):