Skip to content

Commit 6f6916e

Browse files
Expanded tests
1. Added test for arange with negative step 2. Added tests for empty_like, zeros_like, ones_like, full_like
1 parent 4c9d413 commit 6f6916e

File tree

2 files changed

+165
-37
lines changed

2 files changed

+165
-37
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,3 +964,168 @@ def test_full_dtype_inference():
964964
assert dpt.full(10, True).dtype is np.dtype(np.bool_)
965965
assert np.issubdtype(dpt.full(10, 12.3).dtype, np.floating)
966966
assert np.issubdtype(dpt.full(10, 0.3 - 2j).dtype, np.complexfloating)
967+
968+
969+
@pytest.mark.parametrize(
970+
"dt",
971+
_all_dtypes[1:],
972+
)
973+
def test_arange(dt):
974+
try:
975+
q = dpctl.SyclQueue()
976+
except dpctl.SyclQueueCreationError:
977+
pytest.skip("Queue could not be created")
978+
979+
X = dpt.arange(0, 123, dtype=dt, sycl_queue=q)
980+
dt = np.dtype(dt)
981+
if np.issubdtype(dt, np.integer):
982+
assert int(X[47]) == 47
983+
elif np.issubdtype(dt, np.floating):
984+
assert float(X[47]) == 47.0
985+
elif np.issubdtype(dt, np.complexfloating):
986+
assert complex(X[47]) == 47.0 + 0.0j
987+
988+
X1 = dpt.arange(4, dtype=dt, sycl_queue=q)
989+
assert X1.shape == (4,)
990+
991+
X2 = dpt.arange(4, 0, -1, dtype=dt, sycl_queue=q)
992+
assert X2.shape == (4,)
993+
994+
995+
@pytest.mark.parametrize(
996+
"dt",
997+
_all_dtypes,
998+
)
999+
@pytest.mark.parametrize(
1000+
"usm_kind",
1001+
[
1002+
"shared",
1003+
"device",
1004+
"host",
1005+
],
1006+
)
1007+
def test_empty_like(dt, usm_kind):
1008+
try:
1009+
q = dpctl.SyclQueue()
1010+
except dpctl.SyclQueueCreationError:
1011+
pytest.skip("Queue could not be created")
1012+
1013+
X = dpt.empty((4, 5), dtype=dt, usm_type=usm_kind, sycl_queue=q)
1014+
Y = dpt.empty_like(X)
1015+
assert X.shape == Y.shape
1016+
assert X.dtype == Y.dtype
1017+
assert X.usm_type == Y.usm_type
1018+
assert X.sycl_queue == Y.sycl_queue
1019+
1020+
X = dpt.empty(tuple(), dtype=dt, usm_type=usm_kind, sycl_queue=q)
1021+
Y = dpt.empty_like(X)
1022+
assert X.shape == Y.shape
1023+
assert X.dtype == Y.dtype
1024+
assert X.usm_type == Y.usm_type
1025+
assert X.sycl_queue == Y.sycl_queue
1026+
1027+
1028+
@pytest.mark.parametrize(
1029+
"dt",
1030+
_all_dtypes,
1031+
)
1032+
@pytest.mark.parametrize(
1033+
"usm_kind",
1034+
[
1035+
"shared",
1036+
"device",
1037+
"host",
1038+
],
1039+
)
1040+
def test_zeros_like(dt, usm_kind):
1041+
try:
1042+
q = dpctl.SyclQueue()
1043+
except dpctl.SyclQueueCreationError:
1044+
pytest.skip("Queue could not be created")
1045+
1046+
X = dpt.empty((4, 5), dtype=dt, usm_type=usm_kind, sycl_queue=q)
1047+
Y = dpt.zeros_like(X)
1048+
assert X.shape == Y.shape
1049+
assert X.dtype == Y.dtype
1050+
assert X.usm_type == Y.usm_type
1051+
assert X.sycl_queue == Y.sycl_queue
1052+
assert np.allclose(dpt.asnumpy(Y), np.zeros(X.shape, dtype=X.dtype))
1053+
1054+
X = dpt.empty(tuple(), dtype=dt, usm_type=usm_kind, sycl_queue=q)
1055+
Y = dpt.zeros_like(X)
1056+
assert X.shape == Y.shape
1057+
assert X.dtype == Y.dtype
1058+
assert X.usm_type == Y.usm_type
1059+
assert X.sycl_queue == Y.sycl_queue
1060+
assert np.array_equal(dpt.asnumpy(Y), np.zeros(X.shape, dtype=X.dtype))
1061+
1062+
1063+
@pytest.mark.parametrize(
1064+
"dt",
1065+
_all_dtypes,
1066+
)
1067+
@pytest.mark.parametrize(
1068+
"usm_kind",
1069+
[
1070+
"shared",
1071+
"device",
1072+
"host",
1073+
],
1074+
)
1075+
def test_ones_like(dt, usm_kind):
1076+
try:
1077+
q = dpctl.SyclQueue()
1078+
except dpctl.SyclQueueCreationError:
1079+
pytest.skip("Queue could not be created")
1080+
1081+
X = dpt.empty((4, 5), dtype=dt, usm_type=usm_kind, sycl_queue=q)
1082+
Y = dpt.ones_like(X)
1083+
assert X.shape == Y.shape
1084+
assert X.dtype == Y.dtype
1085+
assert X.usm_type == Y.usm_type
1086+
assert X.sycl_queue == Y.sycl_queue
1087+
assert np.allclose(dpt.asnumpy(Y), np.ones(X.shape, dtype=X.dtype))
1088+
1089+
X = dpt.empty(tuple(), dtype=dt, usm_type=usm_kind, sycl_queue=q)
1090+
Y = dpt.ones_like(X)
1091+
assert X.shape == Y.shape
1092+
assert X.dtype == Y.dtype
1093+
assert X.usm_type == Y.usm_type
1094+
assert X.sycl_queue == Y.sycl_queue
1095+
assert np.array_equal(dpt.asnumpy(Y), np.ones(X.shape, dtype=X.dtype))
1096+
1097+
1098+
@pytest.mark.parametrize(
1099+
"dt",
1100+
_all_dtypes,
1101+
)
1102+
@pytest.mark.parametrize(
1103+
"usm_kind",
1104+
[
1105+
"shared",
1106+
"device",
1107+
"host",
1108+
],
1109+
)
1110+
def test_full_like(dt, usm_kind):
1111+
try:
1112+
q = dpctl.SyclQueue()
1113+
except dpctl.SyclQueueCreationError:
1114+
pytest.skip("Queue could not be created")
1115+
1116+
fill_v = np.dtype(dt).type(1)
1117+
X = dpt.empty((4, 5), dtype=dt, usm_type=usm_kind, sycl_queue=q)
1118+
Y = dpt.full_like(X, fill_v)
1119+
assert X.shape == Y.shape
1120+
assert X.dtype == Y.dtype
1121+
assert X.usm_type == Y.usm_type
1122+
assert X.sycl_queue == Y.sycl_queue
1123+
assert np.allclose(dpt.asnumpy(Y), np.ones(X.shape, dtype=X.dtype))
1124+
1125+
X = dpt.empty(tuple(), dtype=dt, usm_type=usm_kind, sycl_queue=q)
1126+
Y = dpt.full_like(X, fill_v)
1127+
assert X.shape == Y.shape
1128+
assert X.dtype == Y.dtype
1129+
assert X.usm_type == Y.usm_type
1130+
assert X.sycl_queue == Y.sycl_queue
1131+
assert np.array_equal(dpt.asnumpy(Y), np.ones(X.shape, dtype=X.dtype))

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -725,40 +725,3 @@ def test_roll_2d(data):
725725
Y = dpt.roll(X, sh, ax)
726726
Ynp = np.roll(Xnp, sh, ax)
727727
assert_array_equal(Ynp, dpt.asnumpy(Y))
728-
729-
730-
@pytest.mark.parametrize(
731-
"dt",
732-
[
733-
"u1",
734-
"i1",
735-
"u2",
736-
"i2",
737-
"u4",
738-
"i4",
739-
"u8",
740-
"i8",
741-
"f2",
742-
"f4",
743-
"f8",
744-
"c8",
745-
"c16",
746-
],
747-
)
748-
def test_arange(dt):
749-
try:
750-
q = dpctl.SyclQueue()
751-
except dpctl.SyclQueueCreationError:
752-
pytest.skip("Queue could not be created")
753-
754-
X = dpt.arange(0, 123, dtype=dt, sycl_queue=q)
755-
dt = np.dtype(dt)
756-
if np.issubdtype(dt, np.integer):
757-
assert int(X[47]) == 47
758-
elif np.issubdtype(dt, np.floating):
759-
assert float(X[47]) == 47.0
760-
elif np.issubdtype(dt, np.complexfloating):
761-
assert complex(X[47]) == 47.0 + 0.0j
762-
763-
X1 = dpt.arange(4, dtype=dt, sycl_queue=q)
764-
assert X1.shape == (4,)

0 commit comments

Comments
 (0)