
Commit f6c22f9

Support consecutive vector indices in Numba backend

1 parent ae66e82

File tree

4 files changed: +287 -67 lines

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 132 additions & 6 deletions
@@ -5,6 +5,7 @@
 from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
 from pytensor.link.utils import compile_function_src, unique_name_generator
 from pytensor.tensor import TensorType
+from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -13,6 +14,7 @@
     IncSubtensor,
     Subtensor,
 )
+from pytensor.tensor.type_other import NoneTypeT, SliceType
 
 
 @numba_funcify.register(Subtensor)
@@ -104,18 +106,61 @@ def {function_name}({", ".join(input_names)}):
 @numba_funcify.register(AdvancedSubtensor)
 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
-    idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
-    adv_idxs_dims = [
-        idx.type.ndim
+    if isinstance(op, AdvancedSubtensor):
+        x, y, idxs = node.inputs[0], None, node.inputs[1:]
+    else:
+        x, y, *idxs = node.inputs
+
+    basic_idxs = [
+        idx
         for idx in idxs
-        if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
+        if (
+            isinstance(idx.type, NoneTypeT)
+            or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
+        )
+    ]
+    adv_idxs = [
+        {
+            "axis": i,
+            "dtype": idx.type.dtype,
+            "bcast": idx.type.broadcastable,
+            "ndim": idx.type.ndim,
+        }
+        for i, idx in enumerate(idxs)
+        if isinstance(idx.type, TensorType)
     ]
 
+    # Special case for consecutive vector indices
+    if (
+        not basic_idxs
+        and len(adv_idxs) >= 2
+        # Must be integer vectors
+        # Todo: we could allow shape=(1,) if this is the shape of x
+        and all(
+            (adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
+            for adv_idx in adv_idxs
+        )
+        # Must be consecutive
+        and not op.non_contiguous_adv_indexing(node)
+        # y in set/inc_subtensor cannot be broadcasted
+        and (
+            y is None
+            or y.type.broadcastable
+            == (
+                x.type.broadcastable[: adv_idxs[0]["axis"]]
+                + x.type.broadcastable[adv_idxs[-1]["axis"] :]
+            )
+        )
+    ):
+        return numba_funcify_multiple_vector_indexing(op, node, **kwargs)
+
+    # Cases natively supported by Numba
     if (
         # Numba does not support indexes with more than one dimension
+        any(idx["ndim"] > 1 for idx in adv_idxs)
         # Nor multiple vector indexes
-        (len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
-        # The default index implementation does not handle duplicate indices correctly
+        or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
+        # The default PyTensor implementation does not handle duplicate indices correctly
         or (
             isinstance(op, AdvancedIncSubtensor)
             and not op.set_instead_of_inc
@@ -127,6 +172,87 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
         return numba_funcify_default_subtensor(op, node, **kwargs)
 
 
+def numba_funcify_multiple_vector_indexing(
+    op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
+):
+    # Special-case implementation for multiple consecutive vector indices (and set/incsubtensor)
+    if isinstance(op, AdvancedSubtensor):
+        y, idxs = None, node.inputs[1:]
+    else:
+        y, *idxs = node.inputs[1:]
+
+    first_axis = next(
+        i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
+    )
+    try:
+        after_last_axis = next(
+            i
+            for i, idx in enumerate(idxs[first_axis:], start=first_axis)
+            if not isinstance(idx.type, TensorType)
+        )
+    except StopIteration:
+        after_last_axis = len(idxs)
+
+    if isinstance(op, AdvancedSubtensor):
+
+        @numba_njit
+        def advanced_subtensor_multiple_vector(x, *idxs):
+            none_slices = idxs[:first_axis]
+            vec_idxs = idxs[first_axis:after_last_axis]
+
+            x_shape = x.shape
+            idx_shape = vec_idxs[0].shape
+            shape_bef = x_shape[:first_axis]
+            shape_aft = x_shape[after_last_axis:]
+            out_shape = (*shape_bef, *idx_shape, *shape_aft)
+            out_buffer = np.empty(out_shape, dtype=x.dtype)
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
+            return out_buffer
+
+        return advanced_subtensor_multiple_vector
+
+    elif op.set_instead_of_inc:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_set_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] = y[*outer, i]
+            return out
+
+        return advanced_set_subtensor_multiple_vector
+
+    else:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] += y[*outer, i]
+            return out
+
+        return advanced_inc_subtensor_multiple_vector
+
+
 @numba_funcify.register(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
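For orientation, here is a minimal NumPy sketch of the semantics the new consecutive-vector-index kernels replicate (the array values and variable names are illustrative, not part of the commit): each position i of the index vectors selects one element, the indexed axes collapse into a single dimension that stays in place, and the set/inc variants write y back position by position, which also keeps duplicate indices correct.

import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)
rows = np.array([0, 2, 1])   # vector index for axis 1
cols = np.array([3, 0, 1])   # vector index for axis 2, consecutive with axis 1

# Advanced subtensor: position i pairs rows[i] with cols[i]
out = x[:, rows, cols]       # shape (2, 3) = (*shape_bef, *idx_shape, *shape_aft)

# Set / increment with the same consecutive vector indices
y = np.ones((2, 3))
x_set = x.copy()
x_set[:, rows, cols] = y                          # what the set kernel computes
x_inc = x.copy()
np.add.at(x_inc, (slice(None), rows, cols), y)    # duplicate-safe increment, like the inc kernel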

pytensor/tensor/subtensor.py

Lines changed: 25 additions & 0 deletions
@@ -2937,6 +2937,31 @@ def grad(self, inpt, output_gradients):
             gy = _sum_grad_over_bcasted_dims(y, gy)
         return [gx, gy] + [DisconnectedType()() for _ in idxs]
 
+    @staticmethod
+    def non_contiguous_adv_indexing(node: Apply) -> bool:
+        """
+        Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
+
+        This function checks if the advanced indexing is non-contiguous,
+        in which case the advanced index dimensions are placed on the left of the
+        output array, regardless of their original position.
+
+        See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
+
+        Parameters
+        ----------
+        node : Apply
+            The node of the AdvancedSubtensor operation.
+
+        Returns
+        -------
+        bool
+            True if the advanced indexing is non-contiguous, False otherwise.
+        """
+        _, _, *idxs = node.inputs
+        return _non_contiguous_adv_indexing(idxs)
+
 
 advanced_inc_subtensor = AdvancedIncSubtensor()
 advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
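A short NumPy illustration of the rule this helper encodes (the shapes are made up for the example): advanced indices are contiguous when no basic index sits between them; once a slice interrupts them, NumPy moves the combined advanced-index dimension to the front of the output, which is exactly the case the Numba special path must reject.

import numpy as np

x = np.zeros((2, 3, 4, 5))
i = np.array([0, 1, 2])
j = np.array([3, 0, 1])

# Contiguous advanced indices (axes 1 and 2): the indexed dims stay in place
print(x[:, i, j].shape)     # (2, 3, 5)

# Non-contiguous: the full slice between i and j interrupts them,
# so the broadcast index dimension is moved to the front of the output
print(x[:, i, :, j].shape)  # (3, 2, 4)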

tests/link/numba/test_basic.py

Lines changed: 12 additions & 5 deletions
@@ -228,9 +228,11 @@ def compare_numba_and_py(
     fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]],
     inputs: Sequence["TensorLike"],
     assert_fn: Callable | None = None,
+    *,
     numba_mode=numba_mode,
     py_mode=py_mode,
     updates=None,
+    inplace: bool = False,
     eval_obj_mode: bool = True,
 ) -> tuple[Callable, Any]:
     """Function to compare python graph output and Numba compiled output for testing equality
@@ -276,7 +278,14 @@ def assert_fn(x, y):
     pytensor_py_fn = function(
         fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
     )
-    py_res = pytensor_py_fn(*inputs)
+
+    test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
+    py_res = pytensor_py_fn(*test_inputs)
+
+    # Get some coverage (and catch errors in python mode before unreadable numba ones)
+    if eval_obj_mode:
+        test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
+        eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode)
 
     pytensor_numba_fn = function(
         fn_inputs,
@@ -285,11 +294,9 @@ def assert_fn(x, y):
         accept_inplace=True,
         updates=updates,
     )
-    numba_res = pytensor_numba_fn(*inputs)
 
-    # Get some coverage
-    if eval_obj_mode:
-        eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
+    test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
+    numba_res = pytensor_numba_fn(*test_inputs)
 
     if len(fn_outputs) > 1:
         for j, p in zip(numba_res, py_res, strict=True):
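A hypothetical call of the extended helper, sketched under the assumption that it runs inside the Numba test suite (the graph and values are illustrative, not from the commit): with inplace=True every backend evaluation receives a fresh copy of the inputs, so an op that destroys its input cannot contaminate the Python/Numba comparison.

import numpy as np
import pytensor.tensor as pt
from tests.link.numba.test_basic import compare_numba_and_py

x = pt.tensor("x", shape=(3, 4), dtype="float64")
idx0 = pt.vector("idx0", dtype="int64")
idx1 = pt.vector("idx1", dtype="int64")
y = pt.vector("y", dtype="float64")

# Two consecutive vector indices -> the new Numba special case applies
out = pt.set_subtensor(x[idx0, idx1], y)

compare_numba_and_py(
    ([x, idx0, idx1, y], [out]),
    [
        np.arange(12.0).reshape(3, 4),
        np.array([0, 2]),
        np.array([1, 3]),
        np.array([100.0, 200.0]),
    ],
    inplace=True,  # copy inputs before each evaluation
)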
