Skip to content

Commit 7e1b918

Browse files
committed
Added dpt.eye test, type error for k keyword
1 parent 3203b9c commit 7e1b918

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

dpctl/tensor/_ctors.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,11 +1083,16 @@ def eye(
10831083
if n_cols is None:
10841084
n_cols = n_rows
10851085
#allocate a 1D array of zeros, length equal to n_cols * n_rows
1086+
k_dt = type(k)
1087+
if not np.issubdtype(k_dt, np.integer):
1088+
raise TypeError(
1089+
"k keyword must be an integer, got {type}".format(type=k_dt)
1090+
)
10861091
x = zeros((n_rows * n_cols,), dtype=dtype, order=order, device=device, usm_type=usm_type, sycl_queue=sycl_queue)
10871092
if k > -n_rows and k < n_cols:
1088-
#find the length of an arbitrary diagonal
1093+
#find the length of the diagonal
10891094
l = min(n_cols, n_rows, n_cols-k, n_rows+k)
1090-
#i is the first element of the diagonal, j is the last, s is the step size
1095+
#i is the first index of the diagonal in 1D index space, j is the last, s is the step size
10911096
if order == "C":
10921097
s = n_cols+1
10931098
i = k if k >= 0 else n_cols*-k

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import ctypes
1818
import numbers
19+
from typing import Type
1920

2021
import numpy as np
2122
import pytest
@@ -1276,6 +1277,8 @@ def test_common_arg_validation():
12761277
dpt.ones_like(X, order=order)
12771278
with pytest.raises(ValueError):
12781279
dpt.full_like(X, 1, order=order)
1280+
with pytest.raises(ValueError):
1281+
dpt.eye(4, order=order)
12791282
X = dict()
12801283
# test for type validation
12811284
with pytest.raises(TypeError):
@@ -1286,3 +1289,23 @@ def test_common_arg_validation():
12861289
dpt.ones_like(X)
12871290
with pytest.raises(TypeError):
12881291
dpt.full_like(X, 1)
1292+
with pytest.raises(TypeError):
1293+
dpt.eye(4, k=1.2)
1294+
1295+
@pytest.mark.parametrize("shapes", [(0,), (1,), (7,), (6, 1), (3, 9), (10,5)])
1296+
@pytest.mark.parametrize("k", np.arange(-4, 5, 1))
1297+
@pytest.mark.parametrize("orders", ["C", "F"])
1298+
def test_eye(shapes, k, orders):
1299+
try:
1300+
q = dpctl.SyclQueue()
1301+
except dpctl.SyclQueueCreationError:
1302+
pytest.skip("Queue could not be created")
1303+
1304+
shape=shapes
1305+
k=k
1306+
order=orders
1307+
1308+
Xnp = np.eye(*shape, k=k, order=order)
1309+
X = dpt.eye(*shape, k=k, order=order, sycl_queue=q)
1310+
1311+
np.testing.assert_array_equal(Xnp, dpt.asnumpy(X))

0 commit comments

Comments
 (0)