Skip to content

Commit 5818d10

Browse files
committed
Change error for incorrect out array dtype to TypeError
1 parent e53b84e commit 5818d10

File tree

2 files changed

+6
-16
lines changed

2 files changed

+6
-16
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,8 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
170170
raise IndexError("cannot take non-empty indices from an empty axis")
171171
res_sh = x_sh[:axis] + inds.shape + x_sh[axis_end:]
172172

173-
orig_out = None
174173
if out is not None:
175-
orig_out = out = dpnp.get_usm_ndarray(out)
174+
out = dpnp.get_usm_ndarray(out)
176175

177176
if not out.flags.writable:
178177
raise ValueError("provided `out` array is read-only")
@@ -184,7 +183,7 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
184183
)
185184

186185
if x.dtype != out.dtype:
187-
raise ValueError(
186+
raise TypeError(
188187
f"Output array of type {x.dtype} is needed, " f"got {out.dtype}"
189188
)
190189

@@ -213,14 +212,6 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
213212
)
214213
_manager.add_event_pair(h_ev, take_ev)
215214

216-
if not (orig_out is None or orig_out is out):
217-
# Copy the out data from temporary buffer to original memory
218-
ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
219-
src=out, dst=orig_out, sycl_queue=q, depends=[take_ev]
220-
)
221-
_manager.add_event_pair(ht_copy_ev, cpy_ev)
222-
out = orig_out
223-
224215
return out
225216

226217

@@ -317,10 +308,9 @@ def compress(condition, a, axis=None, out=None):
317308
# _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
318309
inds = _nonzero_impl(cond_ary)
319310

320-
return dpnp.get_result_array(
321-
_take_index(a_ary, inds[0], axis, exec_q, res_usm_type, out=out),
322-
out=out,
323-
)
311+
res = _take_index(a_ary, inds[0], axis, exec_q, res_usm_type, out=out)
312+
313+
return dpnp.get_result_array(res, out=out)
324314

325315

326316
def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):

dpnp/tests/test_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1365,7 +1365,7 @@ def test_compress_invalid_out_errors():
13651365
with pytest.raises(ExecutionPlacementError):
13661366
dpnp.compress(condition, a, out=out_bad_queue)
13671367
out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1)
1368-
with pytest.raises(ValueError):
1368+
with pytest.raises(TypeError):
13691369
dpnp.compress(condition, a, out=out_bad_dt)
13701370
out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1)
13711371
out_read_only.flags.writable = False

0 commit comments

Comments
 (0)