Skip to content

Commit d3e41ec

Browse files
committed
dptcl.tensor.eye dtype test and order validation
1 parent b4b233a commit d3e41ec

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

dpctl/tensor/_ctors.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,11 +1087,15 @@ def eye(
10871087
if n_cols is None:
10881088
n_cols = n_rows
10891089
# allocate a 1D array of zeros, length equal to n_cols * n_rows
1090-
k_dt = type(k)
1091-
if not np.issubdtype(k_dt, np.integer):
1092-
raise TypeError(
1093-
"k keyword must be an integer, got {type}".format(type=k_dt)
1090+
if not isinstance(order, str) or len(order) == 0 or order[0] not in "CcFf":
1091+
raise ValueError(
1092+
"Unrecognized order keyword value, expecting 'F' or 'C'."
10941093
)
1094+
else:
1095+
order = order[0].upper()
1096+
n_rows = operator.index(n_rows)
1097+
n_cols = operator.index(n_cols)
1098+
k = operator.index(k)
10951099
x = zeros(
10961100
(n_rows * n_cols,),
10971101
dtype=dtype,

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,24 +1288,25 @@ def test_common_arg_validation():
12881288
dpt.ones_like(X)
12891289
with pytest.raises(TypeError):
12901290
dpt.full_like(X, 1)
1291-
with pytest.raises(TypeError):
1292-
dpt.eye(4, k=1.2)
12931291

12941292

1295-
@pytest.mark.parametrize("shapes", [(0,), (1,), (7,), (6, 1), (3, 9), (10, 5)])
1296-
@pytest.mark.parametrize("k", np.arange(-4, 5, 1))
1297-
@pytest.mark.parametrize("orders", ["C", "F"])
1298-
def test_eye(shapes, k, orders):
1293+
@pytest.mark.parametrize("dtype", _all_dtypes)
1294+
def test_eye(dtype):
1295+
X = dpt.eye(4, 5, dtype=dtype)
1296+
Xnp = np.eye(4, 5, dtype=dtype)
1297+
assert X.dtype == Xnp.dtype
1298+
assert np.array_equal(Xnp, dpt.asnumpy(X))
1299+
1300+
1301+
@pytest.mark.parametrize("shape", [(7,), (6, 1), (10, 5), (3, 9)])
1302+
@pytest.mark.parametrize("k", np.arange(-2, 2, 1))
1303+
@pytest.mark.parametrize("order", ["C", "F"])
1304+
def test_eye_shapes(shape, k, order):
12991305
try:
13001306
q = dpctl.SyclQueue()
13011307
except dpctl.SyclQueueCreationError:
13021308
pytest.skip("Queue could not be created")
1303-
1304-
shape = shapes
1305-
k = k
1306-
order = orders
1307-
13081309
Xnp = np.eye(*shape, k=k, order=order)
13091310
X = dpt.eye(*shape, k=k, order=order, sycl_queue=q)
13101311

1311-
np.testing.assert_array_equal(Xnp, dpt.asnumpy(X))
1312+
assert np.array_equal(Xnp, dpt.asnumpy(X))

0 commit comments

Comments
 (0)