Skip to content

Commit 11924c0

Browse files
Extended test_usm_ndarray_dlpack with example from gh-1071
1 parent 8b12038 commit 11924c0

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,30 @@ def test_from_dlpack(shape, typestr, usm_type):
113113
assert V.strides == W.strides
114114

115115

116+
@pytest.mark.parametrize("mod", [2, 5])
117+
def test_from_dlpack_strides(mod, typestr, usm_type):
118+
all_root_devices = dpctl.get_devices()
119+
for sycl_dev in all_root_devices:
120+
skip_if_dtype_not_supported(typestr, sycl_dev)
121+
X0 = dpt.empty(
122+
3 * mod, dtype=typestr, usm_type=usm_type, device=sycl_dev
123+
)
124+
for start in range(mod):
125+
X = X0[slice(-start - 1, None, -mod)]
126+
Y = dpt.from_dlpack(X)
127+
assert X.shape == Y.shape
128+
assert X.dtype == Y.dtype or (
129+
str(X.dtype) == "bool" and str(Y.dtype) == "uint8"
130+
)
131+
assert X.sycl_device == Y.sycl_device
132+
assert X.usm_type == Y.usm_type
133+
assert X._pointer == Y._pointer
134+
if Y.ndim:
135+
V = Y[::-1]
136+
W = dpt.from_dlpack(V)
137+
assert V.strides == W.strides
138+
139+
116140
def test_from_dlpack_input_validation():
117141
vstr = dpt._dlpack.get_build_dlpack_version()
118142
assert type(vstr) is str

0 commit comments

Comments
 (0)