@@ -34,23 +34,41 @@ cdef Py_ssize_t _slice_len(
34
34
return 1 + ((sl_stop - sl_start + 1 ) // sl_step)
35
35
36
36
37
- cdef object _basic_slice_meta(object ind, tuple shape,
38
- tuple strides, Py_ssize_t offset):
37
+ cdef bint _is_integral(object x) except * :
38
+ """ Gives True if x is an integral slice spec"""
39
+ if isinstance (x, (int , numbers.Integral)):
40
+ return True
41
+ if isinstance (x, usm_ndarray):
42
+ if x.ndim > 0 :
43
+ return False
44
+ if x.dtype.kind not in " ui" :
45
+ return False
46
+ return True
47
+ if callable (getattr (x, " __index__" , None )):
48
+ try :
49
+ x.__index__()
50
+ except (TypeError , ValueError ):
51
+ return False
52
+ return True
53
+ return False
54
+
55
+
56
+ def _basic_slice_meta (ind , shape : tuple , strides : tuple , offset : int ):
39
57
"""
40
58
Give basic slicing index `ind` and array layout information produce
41
- a tuple (resulting_shape, resulting_strides, resulting_offset)
42
- used to contruct a view into underlying array.
59
+ a 5-tuple (resulting_shape, resulting_strides, resulting_offset,
60
+ advanced_ind, resulting_advanced_ind_pos)
61
+ used to contruct a view into underlying array over which advanced
62
+ indexing, if any, is to be performed.
43
63
44
- Raises IndexError for invalid index `ind`, and NotImplementedError
45
- if `ind` is an array.
64
+ Raises IndexError for invalid index `ind`.
46
65
"""
47
- is_integral = lambda x : (
48
- isinstance (x, numbers.Integral) or callable (getattr (x, " __index__" , None ))
49
- )
66
+ _no_advanced_ind = tuple ()
67
+ _no_advanced_pos = - 1
50
68
if ind is Ellipsis :
51
- return (shape, strides, offset)
69
+ return (shape, strides, offset, _no_advanced_ind, _no_advanced_pos )
52
70
elif ind is None :
53
- return ((1 ,) + shape, (0 ,) + strides, offset)
71
+ return ((1 ,) + shape, (0 ,) + strides, offset, _no_advanced_ind, _no_advanced_pos )
54
72
elif isinstance (ind, slice ):
55
73
sl_start, sl_stop, sl_step = ind.indices(shape[0 ])
56
74
sh0 = _slice_len(sl_start, sl_stop, sl_step)
@@ -60,38 +78,70 @@ cdef object _basic_slice_meta(object ind, tuple shape,
60
78
return (
61
79
(sh0, ) + shape[1 :],
62
80
new_strides,
63
- new_offset
81
+ new_offset,
82
+ _no_advanced_ind,
83
+ _no_advanced_pos
64
84
)
65
- elif is_integral (ind):
85
+ elif _is_integral (ind):
66
86
ind = ind.__index__()
67
87
if 0 <= ind < shape[0 ]:
68
- return (shape[1 :], strides[1 :], offset + ind * strides[0 ])
88
+ return (shape[1 :], strides[1 :], offset + ind * strides[0 ], _no_advanced_ind, _no_advanced_pos )
69
89
elif - shape[0 ] <= ind < 0 :
70
90
return (shape[1 :], strides[1 :],
71
- offset + (shape[0 ] + ind) * strides[0 ])
91
+ offset + (shape[0 ] + ind) * strides[0 ], _no_advanced_ind, _no_advanced_pos )
72
92
else :
73
93
raise IndexError (
74
94
" Index {0} is out of range for axes 0 with "
75
95
" size {1}" .format(ind, shape[0 ]))
76
- elif isinstance (ind, list ):
77
- raise NotImplemented
96
+ elif isinstance (ind, usm_ndarray ):
97
+ return (shape, strides, 0 , (ind,), 0 )
78
98
elif isinstance (ind, tuple ):
79
99
axes_referenced = 0
80
100
ellipses_count = 0
81
101
newaxis_count = 0
82
102
explicit_index = 0
103
+ array_count = 0
104
+ seen_arrays_yet = False
105
+ array_streak_started = False
106
+ array_streak_interrupted = False
83
107
for i in ind:
84
108
if i is None :
85
- newaxis_count = newaxis_count + 1
109
+ newaxis_count += 1
110
+ if array_streak_started:
111
+ array_streak_interrupted = True
86
112
elif i is Ellipsis :
87
- ellipses_count = ellipses_count + 1
113
+ ellipses_count += 1
114
+ if array_streak_started:
115
+ array_streak_interrupted = True
88
116
elif isinstance (i, slice ):
89
- axes_referenced = axes_referenced + 1
90
- elif is_integral(i):
91
- explicit_index = explicit_index + 1
92
- axes_referenced = axes_referenced + 1
93
- elif isinstance (i, list ):
94
- raise NotImplemented
117
+ axes_referenced += 1
118
+ if array_streak_started:
119
+ array_streak_interrupted = True
120
+ elif _is_integral(i):
121
+ explicit_index += 1
122
+ axes_referenced += 1
123
+ if array_streak_started:
124
+ array_streak_interrupted = True
125
+ elif isinstance (i, usm_ndarray):
126
+ if not seen_arrays_yet:
127
+ seen_arrays_yet = True
128
+ array_streak_started = True
129
+ array_streak_interrupted = False
130
+ if array_streak_interrupted:
131
+ raise IndexError (
132
+ " Advanced indexing array specs may not be "
133
+ " separated by basic slicing specs."
134
+ )
135
+ dt_k = i.dtype.kind
136
+ if dt_k == " b" :
137
+ axes_referenced += i.ndim
138
+ elif dt_k in " ui" :
139
+ axes_referenced += 1
140
+ else :
141
+ raise IndexError (
142
+ " arrays used as indices must be of integer (or boolean) type"
143
+ )
144
+ array_count += 1
95
145
else :
96
146
raise TypeError
97
147
if ellipses_count > 1 :
@@ -108,7 +158,10 @@ cdef object _basic_slice_meta(object ind, tuple shape,
108
158
+ axes_referenced - explicit_index)
109
159
new_shape = list ()
110
160
new_strides = list ()
161
+ new_advanced_ind = list ()
111
162
k = 0
163
+ new_advanced_start_pos = - 1
164
+ advanced_start_pos_set = False
112
165
new_offset = offset
113
166
is_empty = False
114
167
for i in range (len (ind)):
@@ -133,7 +186,7 @@ cdef object _basic_slice_meta(object ind, tuple shape,
133
186
if sh_i == 0 :
134
187
is_empty = True
135
188
k = k_new
136
- elif is_integral (ind_i):
189
+ elif _is_integral (ind_i):
137
190
ind_i = ind_i.__index__()
138
191
if 0 <= ind_i < shape[k]:
139
192
k_new = k + 1
@@ -149,8 +202,25 @@ cdef object _basic_slice_meta(object ind, tuple shape,
149
202
raise IndexError (
150
203
(" Index {0} is out of range for "
151
204
" axes {1} with size {2}" ).format(ind_i, k, shape[k]))
205
+ elif isinstance (ind_i, usm_ndarray):
206
+ if not advanced_start_pos_set:
207
+ new_advanced_start_pos = len (new_shape)
208
+ advanced_start_pos_set = True
209
+ new_advanced_ind.append(ind_i)
210
+ dt_k = ind_i.dtype.kind
211
+ if dt_k == " b" :
212
+ k_new = k + ind_i.ndim
213
+ else :
214
+ k_new = k + 1
215
+ new_shape.extend(shape[k:k_new])
216
+ new_strides.extend(strides[k:k_new])
217
+ k = k_new
152
218
new_shape.extend(shape[k:])
153
219
new_strides.extend(strides[k:])
154
- return (tuple (new_shape), tuple (new_strides), new_offset)
220
+ new_shape_len += len (shape) - k
221
+ # assert len(new_shape) == new_shape_len, f"{len(new_shape)} vs {new_shape_len}"
222
+ # assert len(new_strides) == new_shape_len, f"{len(new_strides)} vs {new_shape_len}"
223
+ # assert len(new_advanced_ind) == array_count
224
+ return (tuple (new_shape), tuple (new_strides), new_offset, tuple (new_advanced_ind), new_advanced_start_pos)
155
225
else :
156
226
raise TypeError
0 commit comments