Skip to content

Commit f75723b

Browse files
Added tests to test_usm_ndarray_indexing
1 parent 19691ca commit f75723b

File tree

1 file changed

+220
-0
lines changed

1 file changed

+220
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,3 +970,223 @@ def test_advanced_indexing_compute_follows_data():
970970
dpt.put(x, ind0, val1, axis=0)
971971
with pytest.raises(ExecutionPlacementError):
972972
x[ind0] = val1
973+
974+
975+
#######
976+
977+
978+
def test_extract_all_1d():
979+
x = dpt.arange(30, dtype="i4")
980+
sel = dpt.ones(30, dtype="?")
981+
sel[::2] = False
982+
983+
res = x[sel]
984+
expected_res = dpt.asnumpy(x)[dpt.asnumpy(sel)]
985+
assert (dpt.asnumpy(res) == expected_res).all()
986+
987+
res2 = dpt.extract(sel, x)
988+
assert (dpt.asnumpy(res2) == expected_res).all()
989+
990+
991+
def test_extract_all_2d():
992+
x = dpt.reshape(dpt.arange(30, dtype="i4"), (5, 6))
993+
sel = dpt.ones(30, dtype="?")
994+
sel[::2] = False
995+
sel = dpt.reshape(sel, x.shape)
996+
997+
res = x[sel]
998+
expected_res = dpt.asnumpy(x)[dpt.asnumpy(sel)]
999+
assert (dpt.asnumpy(res) == expected_res).all()
1000+
1001+
res2 = dpt.extract(sel, x)
1002+
assert (dpt.asnumpy(res2) == expected_res).all()
1003+
1004+
1005+
def test_extract_2D_axis0():
1006+
x = dpt.reshape(dpt.arange(30, dtype="i4"), (5, 6))
1007+
sel = dpt.ones(x.shape[0], dtype="?")
1008+
sel[::2] = False
1009+
1010+
res = x[sel]
1011+
expected_res = dpt.asnumpy(x)[dpt.asnumpy(sel)]
1012+
assert (dpt.asnumpy(res) == expected_res).all()
1013+
1014+
1015+
def test_extract_2D_axis1():
1016+
x = dpt.reshape(dpt.arange(30, dtype="i4"), (5, 6))
1017+
sel = dpt.ones(x.shape[1], dtype="?")
1018+
sel[::2] = False
1019+
1020+
res = x[:, sel]
1021+
expected = dpt.asnumpy(x)[:, dpt.asnumpy(sel)]
1022+
assert (dpt.asnumpy(res) == expected).all()
1023+
1024+
1025+
def test_extract_begin():
1026+
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
1027+
y = dpt.permute_dims(x, (2, 0, 3, 1))
1028+
sel = dpt.zeros((3, 3), dtype="?")
1029+
sel[0, 0] = True
1030+
sel[1, 1] = True
1031+
z = y[sel]
1032+
expected = dpt.asnumpy(y)[[0, 1], [0, 1]]
1033+
assert (dpt.asnumpy(z) == expected).all()
1034+
1035+
1036+
def test_extract_end():
1037+
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
1038+
y = dpt.permute_dims(x, (2, 0, 3, 1))
1039+
sel = dpt.zeros((4, 4), dtype="?")
1040+
sel[0, 0] = True
1041+
z = y[..., sel]
1042+
expected = dpt.asnumpy(y)[..., [0], [0]]
1043+
assert (dpt.asnumpy(z) == expected).all()
1044+
1045+
1046+
def test_extract_middle():
1047+
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
1048+
y = dpt.permute_dims(x, (2, 0, 3, 1))
1049+
sel = dpt.zeros((3, 4), dtype="?")
1050+
sel[0, 0] = True
1051+
z = y[:, sel]
1052+
expected = dpt.asnumpy(y)[:, [0], [0], :]
1053+
assert (dpt.asnumpy(z) == expected).all()
1054+
1055+
1056+
def test_extract_empty_result():
1057+
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
1058+
y = dpt.permute_dims(x, (2, 0, 3, 1))
1059+
sel = dpt.zeros((3, 4), dtype="?")
1060+
z = y[:, sel]
1061+
assert z.shape == (
1062+
y.shape[0],
1063+
0,
1064+
y.shape[3],
1065+
)
1066+
1067+
1068+
def test_place_all_1d():
1069+
x = dpt.arange(10, dtype="i2")
1070+
sel = dpt.zeros(10, dtype="?")
1071+
sel[0::2] = True
1072+
val = dpt.zeros(5, dtype=x.dtype)
1073+
x[sel] = val
1074+
assert (dpt.asnumpy(x) == np.array([0, 1, 0, 3, 0, 5, 0, 7, 0, 9])).all()
1075+
dpt.place(x, sel, dpt.asarray(2))
1076+
assert (dpt.asnumpy(x) == np.array([2, 1, 2, 3, 2, 5, 2, 7, 2, 9])).all()
1077+
1078+
1079+
def test_place_2d_axis0():
1080+
x = dpt.reshape(dpt.arange(12, dtype="i2"), (3, 4))
1081+
sel = dpt.asarray([True, False, True])
1082+
val = dpt.zeros((2, 4), dtype=x.dtype)
1083+
x[sel] = val
1084+
expected_x = np.stack(
1085+
(
1086+
np.zeros(4, dtype="i2"),
1087+
np.arange(4, 8, dtype="i2"),
1088+
np.zeros(4, dtype="i2"),
1089+
)
1090+
)
1091+
assert (dpt.asnumpy(x) == expected_x).all()
1092+
1093+
1094+
def test_place_2d_axis1():
1095+
x = dpt.reshape(dpt.arange(12, dtype="i2"), (3, 4))
1096+
sel = dpt.asarray([True, False, True, False])
1097+
val = dpt.zeros((3, 2), dtype=x.dtype)
1098+
x[:, sel] = val
1099+
expected_x = np.array(
1100+
[[0, 1, 0, 3], [0, 5, 0, 7], [0, 9, 0, 11]], dtype="i2"
1101+
)
1102+
assert (dpt.asnumpy(x) == expected_x).all()
1103+
1104+
1105+
def test_place_2d_axis1_scalar():
1106+
x = dpt.reshape(dpt.arange(12, dtype="i2"), (3, 4))
1107+
sel = dpt.asarray([True, False, True, False])
1108+
val = dpt.zeros(tuple(), dtype=x.dtype)
1109+
x[:, sel] = val
1110+
expected_x = np.array(
1111+
[[0, 1, 0, 3], [0, 5, 0, 7], [0, 9, 0, 11]], dtype="i2"
1112+
)
1113+
assert (dpt.asnumpy(x) == expected_x).all()
1114+
1115+
1116+
def test_place_all_slices():
1117+
x = dpt.reshape(dpt.arange(12, dtype="i2"), (3, 4))
1118+
sel = dpt.asarray(
1119+
[
1120+
[False, True, True, False],
1121+
[True, True, False, False],
1122+
[False, False, True, True],
1123+
],
1124+
dtype="?",
1125+
)
1126+
y = dpt.ones_like(x)
1127+
y[sel] = x[sel]
1128+
1129+
1130+
def test_place_some_slices_begin():
1131+
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
1132+
y = dpt.permute_dims(x, (2, 0, 3, 1))
1133+
sel = dpt.zeros((3, 3), dtype="?")
1134+
sel[0, 0] = True
1135+
sel[1, 1] = True
1136+
z = y[sel]
1137+
w = dpt.zeros_like(y)
1138+
w[sel] = z
1139+
1140+
1141+
def test_place_some_slices_mid():
1142+
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
1143+
y = dpt.permute_dims(x, (2, 0, 3, 1))
1144+
sel = dpt.zeros((3, 4), dtype="?")
1145+
sel[0, 0] = True
1146+
sel[1, 1] = True
1147+
z = y[:, sel]
1148+
w = dpt.zeros_like(y)
1149+
w[:, sel] = z
1150+
1151+
1152+
def test_place_some_slices_end():
1153+
x = dpt.reshape(dpt.arange(3 * 3 * 4 * 4, dtype="i2"), (3, 4, 3, 4))
1154+
y = dpt.permute_dims(x, (2, 0, 3, 1))
1155+
sel = dpt.zeros((4, 4), dtype="?")
1156+
sel[0, 0] = True
1157+
sel[1, 1] = True
1158+
z = y[:, :, sel]
1159+
w = dpt.zeros_like(y)
1160+
w[:, :, sel] = z
1161+
1162+
1163+
def test_place_cycling():
1164+
x = dpt.zeros(10, dtype="f4")
1165+
y = dpt.asarray([2, 3])
1166+
sel = dpt.ones(x.size, dtype="?")
1167+
dpt.place(x, sel, y)
1168+
expected = np.array(
1169+
[
1170+
2,
1171+
3,
1172+
]
1173+
* 5,
1174+
dtype=x.dtype,
1175+
)
1176+
assert (dpt.asnumpy(x) == expected).all()
1177+
1178+
1179+
def test_place_subset():
1180+
x = dpt.zeros(10, dtype="f4")
1181+
y = dpt.ones_like(x)
1182+
sel = dpt.ones(x.size, dtype="?")
1183+
sel[::2] = False
1184+
dpt.place(x, sel, y)
1185+
expected = np.array([1, 3, 5, 7, 9], dtype=x.dtype)
1186+
assert (dpt.asnumpy(x) == expected).all()
1187+
1188+
1189+
def test_nonzero():
1190+
x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3)))
1191+
(i,) = dpt.nonzero(x)
1192+
assert dpt.asnumpy(i) == np.array([3, 4, 5, 6]).all()

0 commit comments

Comments
 (0)