Skip to content

Commit 101bcc2

Browse files
Change to support use of usm_ndarray scalars in slicing
Added tests to that effect.
1 parent e32b395 commit 101bcc2

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

dpctl/tensor/_slicing.pxi

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ cdef object _basic_slice_meta(object ind, tuple shape,
4444
Raises IndexError for invalid index `ind`, and NotImplementedError
4545
if `ind` is an array.
4646
"""
47+
is_integral = lambda x: (
48+
isinstance(x, numbers.Integral) or callable(getattr(x, "__index__", None))
49+
)
4750
if ind is Ellipsis:
4851
return (shape, strides, offset)
4952
elif ind is None:
@@ -58,7 +61,8 @@ cdef object _basic_slice_meta(object ind, tuple shape,
5861
new_strides,
5962
offset + sl_start * strides[0]
6063
)
61-
elif isinstance(ind, numbers.Integral):
64+
elif is_integral(ind):
65+
ind = ind.__index__()
6266
if 0 <= ind < shape[0]:
6367
return (shape[1:], strides[1:], offset + ind * strides[0])
6468
elif -shape[0] <= ind < 0:
@@ -82,7 +86,7 @@ cdef object _basic_slice_meta(object ind, tuple shape,
8286
ellipses_count = ellipses_count + 1
8387
elif isinstance(i, slice):
8488
axes_referenced = axes_referenced + 1
85-
elif isinstance(i, numbers.Integral):
89+
elif is_integral(i):
8690
explicit_index = explicit_index + 1
8791
axes_referenced = axes_referenced + 1
8892
elif isinstance(i, list):
@@ -124,7 +128,8 @@ cdef object _basic_slice_meta(object ind, tuple shape,
124128
new_strides.append(str_i)
125129
new_offset = new_offset + sl_start * strides[k]
126130
k = k_new
127-
elif isinstance(ind_i, numbers.Integral):
131+
elif is_integral(ind_i):
132+
ind_i = ind_i.__index__()
128133
if 0 <= ind_i < shape[k]:
129134
k_new = k + 1
130135
new_offset = new_offset + ind_i * strides[k]

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,14 @@ def test_slicing_basic():
307307
Xusm[:, -128]
308308
with pytest.raises(TypeError):
309309
Xusm[{1, 2, 3, 4, 5, 6, 7}]
310+
X = dpt.usm_ndarray(10, "u1")
311+
X.usm_data.copy_from_host(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09")
312+
int(
313+
X[X[2]]
314+
) # check that objects with __index__ method can be used as indices
315+
Xh = dpm.as_usm_memory(X[X[2] : X[5]]).copy_to_host()
316+
Xnp = np.arange(0, 10, dtype="u1")
317+
assert np.array_equal(Xh, Xnp[Xnp[2] : Xnp[5]])
310318

311319

312320
def test_ctor_invalid_shape():

0 commit comments

Comments
 (0)