Skip to content

Commit 133cfd1

Browse files
committed
Set array namespace in _create_from_usm_ndarray
1 parent 080a96a commit 133cfd1

File tree

1 file changed

+5
-13
lines changed

1 file changed

+5
-13
lines changed

dpnp/dpnp_array.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -353,15 +353,7 @@ def __getitem__(self, key):
353353
key = _get_unwrapped_index_key(key)
354354

355355
item = self._array_obj.__getitem__(key)
356-
if not isinstance(item, dpt.usm_ndarray):
357-
raise RuntimeError(
358-
"Expected dpctl.tensor.usm_ndarray, got {}"
359-
"".format(type(item))
360-
)
361-
362-
res = self.__new__(dpnp_array)
363-
res._array_obj = item
364-
return res
356+
return dpnp._create_from_usm_ndarray(item)
365357

366358
# '__getstate__',
367359

@@ -632,6 +624,7 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray):
632624
)
633625
res = dpnp_array.__new__(dpnp_array)
634626
res._array_obj = usm_ary
627+
res._array_obj._set_namespace(dpnp)
635628
return res
636629

637630
def all(self, axis=None, out=None, keepdims=False, *, where=True):
@@ -1775,17 +1768,16 @@ def transpose(self, *axes):
17751768
if axes_len == 1 and isinstance(axes[0], (tuple, list)):
17761769
axes = axes[0]
17771770

1778-
res = self.__new__(dpnp_array)
17791771
if ndim == 2 and axes_len == 0:
1780-
res._array_obj = self._array_obj.T
1772+
usm_res = self._array_obj.T
17811773
else:
17821774
if len(axes) == 0 or axes[0] is None:
17831775
# self.transpose().shape == self.shape[::-1]
17841776
# self.transpose(None).shape == self.shape[::-1]
17851777
axes = tuple((ndim - x - 1) for x in range(ndim))
17861778

1787-
res._array_obj = dpt.permute_dims(self._array_obj, axes)
1788-
return res
1779+
usm_res = dpt.permute_dims(self._array_obj, axes)
1780+
return dpnp._create_from_usm_ndarray(usm_res)
17891781

17901782
def var(
17911783
self,

0 commit comments

Comments
 (0)