Skip to content

Commit d7fc400

Browse files
Extended _basic_slice_meta to process advanced indexing specs
1 parent 402c1d6 commit d7fc400

File tree

1 file changed

+97
-27
lines changed

1 file changed

+97
-27
lines changed

dpctl/tensor/_slicing.pxi

Lines changed: 97 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,41 @@ cdef Py_ssize_t _slice_len(
3434
return 1 + ((sl_stop - sl_start + 1) // sl_step)
3535

3636

37-
cdef object _basic_slice_meta(object ind, tuple shape,
38-
tuple strides, Py_ssize_t offset):
37+
cdef bint _is_integral(object x) except *:
38+
"""Gives True if x is an integral slice spec"""
39+
if isinstance(x, (int, numbers.Integral)):
40+
return True
41+
if isinstance(x, usm_ndarray):
42+
if x.ndim > 0:
43+
return False
44+
if x.dtype.kind not in "ui":
45+
return False
46+
return True
47+
if callable(getattr(x, "__index__", None)):
48+
try:
49+
x.__index__()
50+
except (TypeError, ValueError):
51+
return False
52+
return True
53+
return False
54+
55+
56+
def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
3957
"""
4058
Give basic slicing index `ind` and array layout information produce
41-
a tuple (resulting_shape, resulting_strides, resulting_offset)
42-
used to contruct a view into underlying array.
59+
a 5-tuple (resulting_shape, resulting_strides, resulting_offset,
60+
advanced_ind, resulting_advanced_ind_pos)
61+
used to contruct a view into underlying array over which advanced
62+
indexing, if any, is to be performed.
4363
44-
Raises IndexError for invalid index `ind`, and NotImplementedError
45-
if `ind` is an array.
64+
Raises IndexError for invalid index `ind`.
4665
"""
47-
is_integral = lambda x: (
48-
isinstance(x, numbers.Integral) or callable(getattr(x, "__index__", None))
49-
)
66+
_no_advanced_ind = tuple()
67+
_no_advanced_pos = -1
5068
if ind is Ellipsis:
51-
return (shape, strides, offset)
69+
return (shape, strides, offset, _no_advanced_ind, _no_advanced_pos)
5270
elif ind is None:
53-
return ((1,) + shape, (0,) + strides, offset)
71+
return ((1,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
5472
elif isinstance(ind, slice):
5573
sl_start, sl_stop, sl_step = ind.indices(shape[0])
5674
sh0 = _slice_len(sl_start, sl_stop, sl_step)
@@ -60,38 +78,70 @@ cdef object _basic_slice_meta(object ind, tuple shape,
6078
return (
6179
(sh0, ) + shape[1:],
6280
new_strides,
63-
new_offset
81+
new_offset,
82+
_no_advanced_ind,
83+
_no_advanced_pos
6484
)
65-
elif is_integral(ind):
85+
elif _is_integral(ind):
6686
ind = ind.__index__()
6787
if 0 <= ind < shape[0]:
68-
return (shape[1:], strides[1:], offset + ind * strides[0])
88+
return (shape[1:], strides[1:], offset + ind * strides[0], _no_advanced_ind, _no_advanced_pos)
6989
elif -shape[0] <= ind < 0:
7090
return (shape[1:], strides[1:],
71-
offset + (shape[0] + ind) * strides[0])
91+
offset + (shape[0] + ind) * strides[0], _no_advanced_ind, _no_advanced_pos)
7292
else:
7393
raise IndexError(
7494
"Index {0} is out of range for axes 0 with "
7595
"size {1}".format(ind, shape[0]))
76-
elif isinstance(ind, list):
77-
raise NotImplemented
96+
elif isinstance(ind, usm_ndarray):
97+
return (shape, strides, 0, (ind,), 0)
7898
elif isinstance(ind, tuple):
7999
axes_referenced = 0
80100
ellipses_count = 0
81101
newaxis_count = 0
82102
explicit_index = 0
103+
array_count = 0
104+
seen_arrays_yet = False
105+
array_streak_started = False
106+
array_streak_interrupted = False
83107
for i in ind:
84108
if i is None:
85-
newaxis_count = newaxis_count + 1
109+
newaxis_count += 1
110+
if array_streak_started:
111+
array_streak_interrupted = True
86112
elif i is Ellipsis:
87-
ellipses_count = ellipses_count + 1
113+
ellipses_count += 1
114+
if array_streak_started:
115+
array_streak_interrupted = True
88116
elif isinstance(i, slice):
89-
axes_referenced = axes_referenced + 1
90-
elif is_integral(i):
91-
explicit_index = explicit_index + 1
92-
axes_referenced = axes_referenced + 1
93-
elif isinstance(i, list):
94-
raise NotImplemented
117+
axes_referenced += 1
118+
if array_streak_started:
119+
array_streak_interrupted = True
120+
elif _is_integral(i):
121+
explicit_index += 1
122+
axes_referenced += 1
123+
if array_streak_started:
124+
array_streak_interrupted = True
125+
elif isinstance(i, usm_ndarray):
126+
if not seen_arrays_yet:
127+
seen_arrays_yet = True
128+
array_streak_started = True
129+
array_streak_interrupted = False
130+
if array_streak_interrupted:
131+
raise IndexError(
132+
"Advanced indexing array specs may not be "
133+
"separated by basic slicing specs."
134+
)
135+
dt_k = i.dtype.kind
136+
if dt_k == "b":
137+
axes_referenced += i.ndim
138+
elif dt_k in "ui":
139+
axes_referenced += 1
140+
else:
141+
raise IndexError(
142+
"arrays used as indices must be of integer (or boolean) type"
143+
)
144+
array_count += 1
95145
else:
96146
raise TypeError
97147
if ellipses_count > 1:
@@ -108,7 +158,10 @@ cdef object _basic_slice_meta(object ind, tuple shape,
108158
+ axes_referenced - explicit_index)
109159
new_shape = list()
110160
new_strides = list()
161+
new_advanced_ind = list()
111162
k = 0
163+
new_advanced_start_pos = -1
164+
advanced_start_pos_set = False
112165
new_offset = offset
113166
is_empty = False
114167
for i in range(len(ind)):
@@ -133,7 +186,7 @@ cdef object _basic_slice_meta(object ind, tuple shape,
133186
if sh_i == 0:
134187
is_empty = True
135188
k = k_new
136-
elif is_integral(ind_i):
189+
elif _is_integral(ind_i):
137190
ind_i = ind_i.__index__()
138191
if 0 <= ind_i < shape[k]:
139192
k_new = k + 1
@@ -149,8 +202,25 @@ cdef object _basic_slice_meta(object ind, tuple shape,
149202
raise IndexError(
150203
("Index {0} is out of range for "
151204
"axes {1} with size {2}").format(ind_i, k, shape[k]))
205+
elif isinstance(ind_i, usm_ndarray):
206+
if not advanced_start_pos_set:
207+
new_advanced_start_pos = len(new_shape)
208+
advanced_start_pos_set = True
209+
new_advanced_ind.append(ind_i)
210+
dt_k = ind_i.dtype.kind
211+
if dt_k == "b":
212+
k_new = k + ind_i.ndim
213+
else:
214+
k_new = k + 1
215+
new_shape.extend(shape[k:k_new])
216+
new_strides.extend(strides[k:k_new])
217+
k = k_new
152218
new_shape.extend(shape[k:])
153219
new_strides.extend(strides[k:])
154-
return (tuple(new_shape), tuple(new_strides), new_offset)
220+
new_shape_len += len(shape) - k
221+
# assert len(new_shape) == new_shape_len, f"{len(new_shape)} vs {new_shape_len}"
222+
# assert len(new_strides) == new_shape_len, f"{len(new_strides)} vs {new_shape_len}"
223+
# assert len(new_advanced_ind) == array_count
224+
return (tuple(new_shape), tuple(new_strides), new_offset, tuple(new_advanced_ind), new_advanced_start_pos)
155225
else:
156226
raise TypeError

0 commit comments

Comments
 (0)