@@ -94,6 +94,7 @@ def __init__(
9494 offset = offset ,
9595 order = order ,
9696 buffer_ctor_kwargs = {"queue" : sycl_queue_normalized },
97+ array_namespace = dpnp ,
9798 )
9899
99100 @property
@@ -201,6 +202,31 @@ def __and__(self, other):
201202 # '__array_ufunc__',
202203 # '__array_wrap__',
203204
205+ def __array_namespace__ (self , / , * , api_version = None ):
206+ """
207+ Returns array namespace, member functions of which implement data API.
208+
209+ Parameters
210+ ----------
211+ api_version : str, optional
212+ Request namespace compliant with given version of array API. If
213+ ``None``, namespace for the most recent supported version is
214+ returned.
215+ Default: ``None``.
216+
217+ Returns
218+ -------
219+ out : any
220+ An object representing the array API namespace. It should have
221+ every top-level function defined in the specification as
222+ an attribute. It may contain other public names as well, but it is
223+ recommended to only include those names that are part of the
224+ specification.
225+
226+ """
227+
228+ return self ._array_obj .__array_namespace__ (api_version = api_version )
229+
204230 def __bool__ (self ):
205231 """``True`` if self else ``False``."""
206232 return self ._array_obj .__bool__ ()
@@ -327,15 +353,7 @@ def __getitem__(self, key):
327353 key = _get_unwrapped_index_key (key )
328354
329355 item = self ._array_obj .__getitem__ (key )
330- if not isinstance (item , dpt .usm_ndarray ):
331- raise RuntimeError (
332- "Expected dpctl.tensor.usm_ndarray, got {}"
333- "" .format (type (item ))
334- )
335-
336- res = self .__new__ (dpnp_array )
337- res ._array_obj = item
338- return res
356+ return dpnp_array ._create_from_usm_ndarray (item )
339357
340358 # '__getstate__',
341359
@@ -606,6 +624,7 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray):
606624 )
607625 res = dpnp_array .__new__ (dpnp_array )
608626 res ._array_obj = usm_ary
627+ res ._array_obj ._set_namespace (dpnp )
609628 return res
610629
611630 def all (self , axis = None , out = None , keepdims = False , * , where = True ):
@@ -1749,17 +1768,16 @@ def transpose(self, *axes):
17491768 if axes_len == 1 and isinstance (axes [0 ], (tuple , list )):
17501769 axes = axes [0 ]
17511770
1752- res = self .__new__ (dpnp_array )
17531771 if ndim == 2 and axes_len == 0 :
1754- res . _array_obj = self ._array_obj .T
1772+ usm_res = self ._array_obj .T
17551773 else :
17561774 if len (axes ) == 0 or axes [0 ] is None :
17571775 # self.transpose().shape == self.shape[::-1]
17581776 # self.transpose(None).shape == self.shape[::-1]
17591777 axes = tuple ((ndim - x - 1 ) for x in range (ndim ))
17601778
1761- res . _array_obj = dpt .permute_dims (self ._array_obj , axes )
1762- return res
1779+ usm_res = dpt .permute_dims (self ._array_obj , axes )
1780+ return dpnp_array . _create_from_usm_ndarray ( usm_res )
17631781
17641782 def var (
17651783 self ,
0 commit comments