@@ -170,9 +170,8 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
170170 raise IndexError ("cannot take non-empty indices from an empty axis" )
171171 res_sh = x_sh [:axis ] + inds .shape + x_sh [axis_end :]
172172
173- orig_out = None
174173 if out is not None :
175- orig_out = out = dpnp .get_usm_ndarray (out )
174+ out = dpnp .get_usm_ndarray (out )
176175
177176 if not out .flags .writable :
178177 raise ValueError ("provided `out` array is read-only" )
@@ -184,7 +183,7 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
184183 )
185184
186185 if x .dtype != out .dtype :
187- raise ValueError (
186+ raise TypeError (
188187 f"Output array of type { x .dtype } is needed, " f"got { out .dtype } "
189188 )
190189
@@ -213,14 +212,6 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
213212 )
214213 _manager .add_event_pair (h_ev , take_ev )
215214
216- if not (orig_out is None or orig_out is out ):
217- # Copy the out data from temporary buffer to original memory
218- ht_copy_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
219- src = out , dst = orig_out , sycl_queue = q , depends = [take_ev ]
220- )
221- _manager .add_event_pair (ht_copy_ev , cpy_ev )
222- out = orig_out
223-
224215 return out
225216
226217
@@ -317,10 +308,9 @@ def compress(condition, a, axis=None, out=None):
317308 # _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
318309 inds = _nonzero_impl (cond_ary )
319310
320- return dpnp .get_result_array (
321- _take_index (a_ary , inds [0 ], axis , exec_q , res_usm_type , out = out ),
322- out = out ,
323- )
311+ res = _take_index (a_ary , inds [0 ], axis , exec_q , res_usm_type , out = out )
312+
313+ return dpnp .get_result_array (res , out = out )
324314
325315
326316def diag_indices (n , ndim = 2 , device = None , usm_type = "device" , sycl_queue = None ):
0 commit comments