Skip to content

Commit 877c3c7

Browse files
committed
Test fixes
- Error for non-integer usm_ndarrays used as indices changed to IndexError
1 parent d42b019 commit 877c3c7

File tree

2 files changed

+53
-28
lines changed

2 files changed

+53
-28
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def take(x, indices, /, *, axis=None, mode="clip"):
4848
)
4949
)
5050
if not np.issubdtype(i.dtype, np.integer):
51-
raise TypeError(
51+
raise IndexError(
5252
"`indices` expected integer data type, got `{}`".format(i.dtype)
5353
)
5454
queues_.append(i.sycl_queue)
@@ -126,7 +126,7 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
126126
)
127127
)
128128
if not np.issubdtype(i.dtype, np.integer):
129-
raise TypeError(
129+
raise IndexError(
130130
"`indices` expected integer data type, got `{}`".format(i.dtype)
131131
)
132132
queues_.append(i.sycl_queue)

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,21 @@ def test_put_basic_axis():
535535
assert (expected == dpt.asnumpy(x)).all()
536536

537537

538+
@pytest.mark.parametrize("data_dt", _all_dtypes)
539+
def test_put_0d_val(data_dt):
540+
q = get_queue_or_skip()
541+
skip_if_dtype_not_supported(data_dt, q)
542+
543+
x = dpt.arange(5, dtype=data_dt, sycl_queue=q)
544+
ind = dpt.asarray([0], dtype=np.intp, sycl_queue=q)
545+
x[ind] = 2
546+
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x[0]))
547+
548+
x = dpt.asarray(5, dtype=data_dt, sycl_queue=q)
549+
x[ind] = 2
550+
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x))
551+
552+
538553
@pytest.mark.parametrize(
539554
"data_dt",
540555
_all_dtypes,
@@ -543,8 +558,8 @@ def test_take_0d_data(data_dt):
543558
q = get_queue_or_skip()
544559
skip_if_dtype_not_supported(data_dt, q)
545560

546-
x = dpt.asarray(0, dtype=data_dt)
547-
ind = dpt.arange(5)
561+
x = dpt.asarray(0, dtype=data_dt, sycl_queue=q)
562+
ind = dpt.arange(5, dtype=np.intp, sycl_queue=q)
548563

549564
y = dpt.take(x, ind)
550565
assert (
@@ -561,9 +576,9 @@ def test_put_0d_data(data_dt):
561576
q = get_queue_or_skip()
562577
skip_if_dtype_not_supported(data_dt, q)
563578

564-
x = dpt.asarray(0, dtype=data_dt)
565-
ind = dpt.arange(5)
566-
val = dpt.asarray(2, dtype=data_dt)
579+
x = dpt.asarray(0, dtype=data_dt, sycl_queue=q)
580+
ind = dpt.arange(5, dtype=np.intp, sycl_queue=q)
581+
val = dpt.asarray(2, dtype=data_dt, sycl_queue=q)
567582

568583
dpt.put(x, ind, val, axis=0)
569584
assert (
@@ -577,10 +592,10 @@ def test_put_0d_data(data_dt):
577592
_all_int_dtypes,
578593
)
579594
def test_take_0d_ind(ind_dt):
580-
get_queue_or_skip()
595+
q = get_queue_or_skip()
581596

582-
x = dpt.arange(5, dtype=ind_dt)
583-
ind = dpt.asarray(3)
597+
x = dpt.arange(5, dtype="i4", sycl_queue=q)
598+
ind = dpt.asarray(3, dtype=ind_dt, sycl_queue=q)
584599

585600
y = dpt.take(x, ind)
586601
assert dpt.asnumpy(x[3]) == dpt.asnumpy(y)
@@ -591,11 +606,11 @@ def test_take_0d_ind(ind_dt):
591606
_all_int_dtypes,
592607
)
593608
def test_put_0d_ind(ind_dt):
594-
get_queue_or_skip()
609+
q = get_queue_or_skip()
595610

596-
x = dpt.arange(5, dtype=ind_dt)
597-
ind = dpt.asarray(3)
598-
val = dpt.asarray(5, dtype=ind_dt)
611+
x = dpt.arange(5, dtype="i4", sycl_queue=q)
612+
ind = dpt.asarray(3, dtype=ind_dt, sycl_queue=q)
613+
val = dpt.asarray(5, dtype=x.dtype, sycl_queue=q)
599614

600615
dpt.put(x, ind, val, axis=0)
601616
assert dpt.asnumpy(x[3]) == dpt.asnumpy(val)
@@ -750,7 +765,7 @@ def test_put_strided_1d_destination(data_dt, order):
750765

751766
x = dpt.arange(27, dtype=data_dt, sycl_queue=q)
752767
ind = dpt.arange(4, 9, dtype=np.intp, sycl_queue=q)
753-
val = dpt.asarray(9, dtype=data_dt, sycl_queue=q)
768+
val = dpt.asarray(9, dtype=x.dtype, sycl_queue=q)
754769

755770
x_np = dpt.asnumpy(x)
756771
ind_np = dpt.asnumpy(ind)
@@ -780,7 +795,7 @@ def test_put_strided_destination(data_dt, order):
780795

781796
x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order)
782797
ind = dpt.arange(2, dtype=np.intp, sycl_queue=q)
783-
val = dpt.asarray(9, dtype=data_dt, sycl_queue=q)
798+
val = dpt.asarray(9, dtype=x.dtype, sycl_queue=q)
784799

785800
x_np = dpt.asnumpy(x)
786801
ind_np = dpt.asnumpy(ind)
@@ -825,7 +840,7 @@ def test_put_strided_1d_indices(ind_dt):
825840

826841
x = dpt.arange(27, dtype="i4", sycl_queue=q)
827842
ind = dpt.arange(12, 24, dtype=ind_dt, sycl_queue=q)
828-
val = dpt.asarray(-1, dtype="i4", sycl_queue=q)
843+
val = dpt.asarray(-1, dtype=x.dtype, sycl_queue=q)
829844

830845
x_np = dpt.asnumpy(x)
831846
ind_np = dpt.asnumpy(ind).astype(np.intp)
@@ -880,43 +895,53 @@ def test_put_strided_indices(ind_dt, order):
880895

881896

882897
def test_take_arg_validation():
883-
get_queue_or_skip()
898+
q = get_queue_or_skip()
884899

885-
x = dpt.arange(4)
886-
ind0 = dpt.arange(2)
887-
ind1 = dpt.arange(2.0)
900+
x = dpt.arange(4, dtype="i4", sycl_queue=q)
901+
ind0 = dpt.arange(2, dtype=np.intp, sycl_queue=q)
902+
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)
888903

889-
with pytest.raises(ValueError):
890-
dpt.take(dpt.reshape(x, (2, 2)), ind0)
891904
with pytest.raises(TypeError):
892905
dpt.take(dict(), ind0, axis=0)
893906
with pytest.raises(TypeError):
894907
dpt.take(x, dict(), axis=0)
895908
with pytest.raises(TypeError):
909+
x[[]]
910+
with pytest.raises(IndexError):
896911
dpt.take(x, ind1, axis=0)
912+
with pytest.raises(IndexError):
913+
x[ind1]
897914

915+
with pytest.raises(ValueError):
916+
dpt.take(dpt.reshape(x, (2, 2)), ind0)
898917
with pytest.raises(ValueError):
899918
dpt.take(x, ind0, mode=0)
900919
with pytest.raises(ValueError):
901920
dpt.take(dpt.reshape(x, (2, 2)), ind0, axis=None)
902921

903922

904923
def test_put_arg_validation():
905-
get_queue_or_skip()
924+
q = get_queue_or_skip()
906925

907-
x = dpt.arange(4)
908-
ind0 = dpt.arange(2)
909-
ind1 = dpt.arange(2.0)
910-
val = dpt.asarray(2)
926+
x = dpt.arange(4, dtype="i4", sycl_queue=q)
927+
ind0 = dpt.arange(2, dtype=np.intp, sycl_queue=q)
928+
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)
929+
val = dpt.asarray(2, x.dtype, sycl_queue=q)
911930

912931
with pytest.raises(TypeError):
913932
dpt.put(dict(), ind0, val, axis=0)
914933
with pytest.raises(TypeError):
915934
dpt.put(x, dict(), val, axis=0)
916935
with pytest.raises(TypeError):
936+
x[[]] = val
937+
with pytest.raises(IndexError):
917938
dpt.put(x, ind1, val, axis=0)
939+
with pytest.raises(IndexError):
940+
x[ind1] = val
918941
with pytest.raises(TypeError):
919942
dpt.put(x, ind0, dict(), axis=0)
943+
with pytest.raises(TypeError):
944+
x[ind0] = dict()
920945

921946
with pytest.raises(ValueError):
922947
dpt.put(x, ind0, val, mode=0)

0 commit comments

Comments
 (0)