Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,11 @@ cdef class usm_ndarray:

from ._reshape import reshaped_strides

new_nd = len(new_shape)
try:
new_nd = len(new_shape)
except TypeError:
new_nd = 1
new_shape = (new_shape,)
try:
new_shape = tuple(operator.index(dim) for dim in new_shape)
except TypeError:
Expand Down
43 changes: 21 additions & 22 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ctypes
import numbers
from math import prod

import numpy as np
import pytest
Expand Down Expand Up @@ -1102,7 +1103,7 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
skip_if_dtype_not_supported(dtype, q)
shape = (2, 4, 3)
Xnp = (
np.random.randint(-10, 10, size=np.prod(shape))
np.random.randint(-10, 10, size=prod(shape))
.astype(dtype)
.reshape(shape)
)
Expand Down Expand Up @@ -1307,6 +1308,10 @@ def relaxed_strides_equal(st1, st2, sh):
X = dpt.usm_ndarray(sh_s, dtype="?")
X.shape = sh_f
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
sz = X.size
X.shape = sz
assert X.shape == (sz,)
assert relaxed_strides_equal(X.strides, (1,), (sz,))

X = dpt.usm_ndarray(sh_s, dtype="u4")
with pytest.raises(TypeError):
Expand Down Expand Up @@ -2077,11 +2082,9 @@ def test_tril(dtype):
skip_if_dtype_not_supported(dtype, q)

shape = (2, 3, 4, 5, 5)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape
)
X = dpt.reshape(dpt.arange(prod(shape), dtype=dtype, sycl_queue=q), shape)
Y = dpt.tril(X)
Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
Xnp = np.arange(prod(shape), dtype=dtype).reshape(shape)
Ynp = np.tril(Xnp)
assert Y.dtype == Ynp.dtype
assert np.array_equal(Ynp, dpt.asnumpy(Y))
Expand All @@ -2093,11 +2096,9 @@ def test_triu(dtype):
skip_if_dtype_not_supported(dtype, q)

shape = (4, 5)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape
)
X = dpt.reshape(dpt.arange(prod(shape), dtype=dtype, sycl_queue=q), shape)
Y = dpt.triu(X, k=1)
Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
Xnp = np.arange(prod(shape), dtype=dtype).reshape(shape)
Ynp = np.triu(Xnp, k=1)
assert Y.dtype == Ynp.dtype
assert np.array_equal(Ynp, dpt.asnumpy(Y))
Expand All @@ -2110,7 +2111,7 @@ def test_tri_usm_type(tri_fn, usm_type):
dtype = dpt.uint16

shape = (2, 3, 4, 5, 5)
size = np.prod(shape)
size = prod(shape)
X = dpt.reshape(
dpt.arange(size, dtype=dtype, usm_type=usm_type, sycl_queue=q), shape
)
Expand All @@ -2129,11 +2130,11 @@ def test_tril_slice():
q = get_queue_or_skip()

shape = (6, 10)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape
)[1:, ::-2]
X = dpt.reshape(dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape)[
1:, ::-2
]
Y = dpt.tril(X)
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape)[1:, ::-2]
Xnp = np.arange(prod(shape), dtype="int").reshape(shape)[1:, ::-2]
Ynp = np.tril(Xnp)
assert Y.dtype == Ynp.dtype
assert np.array_equal(Ynp, dpt.asnumpy(Y))
Expand All @@ -2144,14 +2145,12 @@ def test_triu_permute_dims():

shape = (2, 3, 4, 5)
X = dpt.permute_dims(
dpt.reshape(
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape
),
dpt.reshape(dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape),
(3, 2, 1, 0),
)
Y = dpt.triu(X)
Xnp = np.transpose(
np.arange(np.prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0)
np.arange(prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0)
)
Ynp = np.triu(Xnp)
assert Y.dtype == Ynp.dtype
Expand Down Expand Up @@ -2189,12 +2188,12 @@ def test_triu_order_k(order, k):

shape = (3, 3)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q),
dpt.arange(prod(shape), dtype="int", sycl_queue=q),
shape,
order=order,
)
Y = dpt.triu(X, k=k)
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
Xnp = np.arange(prod(shape), dtype="int").reshape(shape, order=order)
Ynp = np.triu(Xnp, k=k)
assert Y.dtype == Ynp.dtype
assert X.flags == Y.flags
Expand All @@ -2210,12 +2209,12 @@ def test_tril_order_k(order, k):
pytest.skip("Queue could not be created")
shape = (3, 3)
X = dpt.reshape(
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q),
dpt.arange(prod(shape), dtype="int", sycl_queue=q),
shape,
order=order,
)
Y = dpt.tril(X, k=k)
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
Xnp = np.arange(prod(shape), dtype="int").reshape(shape, order=order)
Ynp = np.tril(Xnp, k=k)
assert Y.dtype == Ynp.dtype
assert X.flags == Y.flags
Expand Down
Loading