Skip to content

Commit d42b019

Browse files
committed
Put calls in tests corrected, organized put logic
1 parent 156f7f0 commit d42b019

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,16 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
104104
raise TypeError(
105105
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
106106
)
107-
queues_ = [
108-
x.sycl_queue,
109-
]
110-
usm_types_ = [
111-
x.usm_type,
112-
]
107+
if isinstance(vals, dpt.usm_ndarray):
108+
queues_ = [x.sycl_queue, vals.sycl_queue]
109+
usm_types_ = [x.usm_type, vals.usm_type]
110+
else:
111+
queues_ = [
112+
x.sycl_queue,
113+
]
114+
usm_types_ = [
115+
x.usm_type,
116+
]
113117

114118
if not isinstance(indices, list) and not isinstance(indices, tuple):
115119
indices = (indices,)

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ def test_put_0d_data(data_dt):
565565
ind = dpt.arange(5)
566566
val = dpt.asarray(2, dtype=data_dt)
567567

568-
dpt.put(x, ind, val)
568+
dpt.put(x, ind, val, axis=0)
569569
assert (
570570
dpt.asnumpy(x)
571571
== np.broadcast_to(np.asarray(2, dtype=data_dt), ind.shape)
@@ -597,7 +597,7 @@ def test_put_0d_ind(ind_dt):
597597
ind = dpt.asarray(3)
598598
val = dpt.asarray(5, dtype=ind_dt)
599599

600-
dpt.put(x, ind, val)
600+
dpt.put(x, ind, val, axis=0)
601601
assert dpt.asnumpy(x[3]) == dpt.asnumpy(val)
602602

603603

@@ -886,6 +886,8 @@ def test_take_arg_validation():
886886
ind0 = dpt.arange(2)
887887
ind1 = dpt.arange(2.0)
888888

889+
with pytest.raises(ValueError):
890+
dpt.take(dpt.reshape(x, (2, 2)), ind0)
889891
with pytest.raises(TypeError):
890892
dpt.take(dict(), ind0, axis=0)
891893
with pytest.raises(TypeError):
@@ -935,10 +937,10 @@ def test_advanced_indexing_compute_follows_data():
935937
with pytest.raises(ExecutionPlacementError):
936938
x[ind1]
937939
with pytest.raises(ExecutionPlacementError):
938-
dpt.put(x, ind1, val0)
940+
dpt.put(x, ind1, val0, axis=0)
939941
with pytest.raises(ExecutionPlacementError):
940942
x[ind1] = val0
941943
with pytest.raises(ExecutionPlacementError):
942-
dpt.put(x, ind0, val1)
944+
dpt.put(x, ind0, val1, axis=0)
943945
with pytest.raises(ExecutionPlacementError):
944946
x[ind0] = val1

0 commit comments

Comments
 (0)