Skip to content

Commit 9be7162

Browse files
committed
Rewrote test for dptcl.tensor.eye
1 parent 93bb3e1 commit 9be7162

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,24 @@ def test_full_like(dt, usm_kind):
12561256
assert np.array_equal(dpt.asnumpy(Y), np.ones(X.shape, dtype=X.dtype))
12571257

12581258

1259+
@pytest.mark.parametrize("dtype", _all_dtypes)
1260+
@pytest.mark.parametrize("usm_kind", ["shared", "device", "host"])
1261+
def test_eye(dtype, usm_kind):
1262+
try:
1263+
q = dpctl.SyclQueue()
1264+
except dpctl.SyclQueueCreationError:
1265+
pytest.skip("Queue could not be created")
1266+
1267+
if dtype in ["f8", "c16"] and q.sycl_device.has_aspect_fp64 is False:
1268+
pytest.skip(
1269+
"Device does not support double precision floating point type"
1270+
)
1271+
X = dpt.eye(4, 5, k=1, dtype=dtype, usm_type=usm_kind, sycl_queue=q)
1272+
Xnp = np.eye(4, 5, k=1, dtype=dtype)
1273+
assert X.dtype == Xnp.dtype
1274+
assert np.array_equal(Xnp, dpt.asnumpy(X))
1275+
1276+
12591277
def test_common_arg_validation():
12601278
order = "I"
12611279
# invalid order must raise ValueError
@@ -1267,6 +1285,8 @@ def test_common_arg_validation():
12671285
dpt.ones(10, order=order)
12681286
with pytest.raises(ValueError):
12691287
dpt.full(10, 1, order=order)
1288+
with pytest.raises(ValueError):
1289+
dpt.eye(10, order=order)
12701290
X = dpt.empty(10)
12711291
with pytest.raises(ValueError):
12721292
dpt.empty_like(X, order=order)
@@ -1288,25 +1308,3 @@ def test_common_arg_validation():
12881308
dpt.ones_like(X)
12891309
with pytest.raises(TypeError):
12901310
dpt.full_like(X, 1)
1291-
1292-
1293-
@pytest.mark.parametrize("dtype", _all_dtypes)
1294-
def test_eye(dtype):
1295-
X = dpt.eye(4, 5, dtype=dtype)
1296-
Xnp = np.eye(4, 5, dtype=dtype)
1297-
assert X.dtype == Xnp.dtype
1298-
assert np.array_equal(Xnp, dpt.asnumpy(X))
1299-
1300-
1301-
@pytest.mark.parametrize("shape", [(7,), (6, 1), (10, 5), (3, 9)])
1302-
@pytest.mark.parametrize("k", np.arange(-2, 2, 1))
1303-
@pytest.mark.parametrize("order", ["C", "F"])
1304-
def test_eye_shapes(shape, k, order):
1305-
try:
1306-
q = dpctl.SyclQueue()
1307-
except dpctl.SyclQueueCreationError:
1308-
pytest.skip("Queue could not be created")
1309-
Xnp = np.eye(*shape, k=k, order=order)
1310-
X = dpt.eye(*shape, k=k, order=order, sycl_queue=q)
1311-
1312-
assert np.array_equal(Xnp, dpt.asnumpy(X))

0 commit comments

Comments
 (0)