@@ -353,15 +353,7 @@ def __getitem__(self, key):
353353 key = _get_unwrapped_index_key (key )
354354
355355 item = self ._array_obj .__getitem__ (key )
356- if not isinstance (item , dpt .usm_ndarray ):
357- raise RuntimeError (
358- "Expected dpctl.tensor.usm_ndarray, got {}"
359- "" .format (type (item ))
360- )
361-
362- res = self .__new__ (dpnp_array )
363- res ._array_obj = item
364- return res
356+ return dpnp ._create_from_usm_ndarray (item )
365357
366358 # '__getstate__',
367359
@@ -632,6 +624,7 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray):
632624 )
633625 res = dpnp_array .__new__ (dpnp_array )
634626 res ._array_obj = usm_ary
627+ res ._array_obj ._set_namespace (dpnp )
635628 return res
636629
637630 def all (self , axis = None , out = None , keepdims = False , * , where = True ):
@@ -1775,17 +1768,16 @@ def transpose(self, *axes):
17751768 if axes_len == 1 and isinstance (axes [0 ], (tuple , list )):
17761769 axes = axes [0 ]
17771770
1778- res = self .__new__ (dpnp_array )
17791771 if ndim == 2 and axes_len == 0 :
1780- res . _array_obj = self ._array_obj .T
1772+ usm_res = self ._array_obj .T
17811773 else :
17821774 if len (axes ) == 0 or axes [0 ] is None :
17831775 # self.transpose().shape == self.shape[::-1]
17841776 # self.transpose(None).shape == self.shape[::-1]
17851777 axes = tuple ((ndim - x - 1 ) for x in range (ndim ))
17861778
1787- res . _array_obj = dpt .permute_dims (self ._array_obj , axes )
1788- return res
1779+ usm_res = dpt .permute_dims (self ._array_obj , axes )
1780+ return dpnp . _create_from_usm_ndarray ( usm_res )
17891781
17901782 def var (
17911783 self ,
0 commit comments