
Commit 52f7d62

Support more cases of multi-dimensional advanced indexing and updating in Numba
Extends the pre-existing rewrite to ravel multiple integer indices and to place them consecutively. The following cases should now be supported without object mode:

* Advanced integer indexing (not mixed with basic or boolean indexing) that does not require broadcasting of the indices
* Consecutive advanced integer index updating (set/inc; not mixed with basic or boolean indexing) that does not require broadcasting of the indices or of y
1 parent e4c1af9 commit 52f7d62
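For illustration only (not part of the commit): a minimal sketch of the kind of graph that should now compile in Numba mode without the object-mode fallback. The mode string and test values are assumptions.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")
# Two multidimensional integer indices with identical (non-broadcasting) shapes
idx = pt.as_tensor(np.array([[1, 2], [2, 1]]))

out = x[idx, idx]  # multidimensional advanced integer indexing
fn = pytensor.function([x], out, mode="NUMBA")  # assumed mode string

x_val = np.arange(3 * 4 * 5, dtype="float64").reshape((3, 4, 5))
np.testing.assert_allclose(fn(x_val), x_val[[[1, 2], [2, 1]], [[1, 2], [2, 1]]])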

File tree: 4 files changed, +172 −50 lines

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
             for adv_idx in adv_idxs
         )
         # Must be consecutive
-        and not op.non_contiguous_adv_indexing(node)
+        and not op.non_consecutive_adv_indexing(node)
         # y in set/inc_subtensor cannot be broadcasted
         and (
             y is None
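The renamed guard is a static method on the Op. A hedged sketch of querying it on a graph with split advanced indices; only `non_consecutive_adv_indexing` is taken from the diff, the rest is illustrative:

import pytensor.tensor as pt
from pytensor.tensor.subtensor import AdvancedSubtensor

x = pt.tensor3("x")
out = x[[1, 2], :, [3, 4]]  # advanced indices separated by a slice

node = out.owner
assert isinstance(node.op, AdvancedSubtensor)
assert node.op.non_consecutive_adv_indexing(node)  # split by basic indexing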

pytensor/tensor/rewriting/subtensor.py

Lines changed: 101 additions & 28 deletions
@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     return [copy_stack_trace(node.outputs[0], new_out)]
 
 
-@node_rewriter(tracks=[AdvancedSubtensor])
+@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
 def ravel_multidimensional_int_idx(fgraph, node):
-    """Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba
-
-    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
+    """Convert multidimensional integer indexing into equivalent consecutive vector integer index,
+    supported by Numba or by our specialized dispatchers
 
+        x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
 
     NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
 
-    x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+        x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    It also handles multiple integer indices, but only if they don't broadcast
+
+        x[eye(3,), 2:, eye(3)] -> x[eye(3), eye(3), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
+
+        x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
+
     """
-    x, *idxs = node.inputs
+    op = node.op
+    non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
+    is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
+
+    if is_inc_subtensor:
+        x, y, *idxs = node.inputs
+        # Inc/SetSubtensor is harder to reason about due to y
+        # We get out if it's broadcasting or if the advanced indices are non-consecutive
+        if non_consecutive_adv_indexing or (
+            y.type.broadcastable != x[tuple(idxs)].type.broadcastable
+        ):
+            return None
+
+    else:
+        x, *idxs = node.inputs
 
     if any(
         (
@@ -2049,39 +2072,89 @@ def ravel_multidimensional_int_idx(fgraph, node):
         )
         for idx in idxs
     ):
-        # Get out if there are any other advanced indexes or np.newaxis
+        # Get out if there are any other advanced indices or np.newaxis
        return None
 
-    int_idxs = [
+    int_idxs_and_pos = [
         (i, idx)
         for i, idx in enumerate(idxs)
         if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
     ]
 
-    if len(int_idxs) != 1:
-        # Get out if there are no or multiple integer idxs
+    if not int_idxs_and_pos:
         return None
 
-    [(int_idx_pos, int_idx)] = int_idxs
-    if int_idx.type.ndim < 2:
-        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+    int_idxs_pos, int_idxs = zip(
+        *int_idxs_and_pos, strict=False
+    )  # strict=False because by definition it's true
+
+    first_int_idx_pos = int_idxs_pos[0]
+    first_int_idx = int_idxs[0]
+    first_int_idx_bcast = first_int_idx.type.broadcastable
+
+    if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
+        # We don't have a view-only broadcasting operation
+        # Explicitly broadcasting the indices can incur a memory / copy overhead
         return None
 
-    raveled_int_idx = int_idx.ravel()
-    new_idxs = list(idxs)
-    new_idxs[int_idx_pos] = raveled_int_idx
-    raveled_subtensor = x[tuple(new_idxs)]
-
-    # Reshape into correct shape
-    # Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
-    # must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
-    raveled_shape = raveled_subtensor.shape
-    unraveled_shape = (
-        *raveled_shape[:int_idx_pos],
-        *int_idx.shape,
-        *raveled_shape[int_idx_pos + 1 :],
-    )
-    new_out = raveled_subtensor.reshape(unraveled_shape)
+    int_idxs_ndim = len(first_int_idx_bcast)
+    int_idxs_need_raveling = (
+        int_idxs_ndim != 1
+    )  # 0-ndim would be basic indexing, but we handle just in case
+
+    if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
+        # Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
+        return None
+
+    # Reorder non-consecutive indices
+    if non_consecutive_adv_indexing:
+        assert not is_inc_subtensor  # Sanity check that we got out if this was the case
+        # This case works as if all the advanced indices were on the front
+        transposition = list(int_idxs_pos) + [
+            i for i in range(len(idxs)) if i not in int_idxs_pos
+        ]
+        idxs = tuple(idxs[a] for a in transposition)
+        x = x.transpose(transposition)
+        first_int_idx_pos = 0
+        del int_idxs_pos  # Make sure they are not wrongly used
+
+    # Ravel multidimensional indices
+    if int_idxs_need_raveling:
+        idxs = list(idxs)
+        for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
+            idxs[idx_pos] = int_idx.ravel()
+
+    # Index with reordered and/or raveled indices
+    new_subtensor = x[tuple(idxs)]
+
+    if is_inc_subtensor:
+        int_idx_ndim = len(first_int_idx_bcast)
+        y_shape = tuple(y.shape)
+        y_raveled_shape = (
+            *y_shape[:first_int_idx_pos],
+            -1,
+            *y_shape[first_int_idx_pos + int_idx_ndim :],
+        )
+        y_raveled = y.reshape(y_raveled_shape)
+
+        new_out = inc_subtensor(
+            new_subtensor,
+            y_raveled,
+            set_instead_of_inc=op.set_instead_of_inc,
+            ignore_duplicates=op.ignore_duplicates,
+            inplace=op.inplace,
+        )
+
+    else:
+        # Unravel advanced indexing dimensions
+        raveled_shape = tuple(new_subtensor.shape)
+        unraveled_shape = (
+            *raveled_shape[:first_int_idx_pos],
+            *first_int_idx.shape,
+            *raveled_shape[first_int_idx_pos + 1 :],
+        )
+        new_out = new_subtensor.reshape(unraveled_shape)
 
     return [copy_stack_trace(node.outputs[0], new_out)]
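The identities this rewrite relies on can be checked directly in NumPy; a sketch of the docstring's examples (not code from the commit):

import numpy as np

x = np.arange(3 * 4).reshape((3, 4)).astype(float)
idx = np.eye(3, dtype=int)  # multidimensional (3, 3) integer index

# Subtensor: index with the raveled vector, then restore the advanced dims
direct = x[idx]  # shape (3, 3, 4)
raveled = x[idx.ravel()].reshape((*idx.shape, *x.shape[1:]))
np.testing.assert_array_equal(direct, raveled)

# SetSubtensor: ravel the index and flatten y's advanced dims to match
y = np.arange(3 * 3 * 2, dtype=float).reshape((3, 3, 2))
direct_set = x.copy()
direct_set[idx, 2:] = y
raveled_set = x.copy()
raveled_set[idx.ravel(), 2:] = y.reshape((-1, *y.shape[2:]))
np.testing.assert_array_equal(direct_set, raveled_set)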

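For non-consecutive indices, the rewrite exploits NumPy's rule that split advanced-index groups land at the front of the result; a sketch of the transposition it performs (illustrative values):

import numpy as np

x = np.arange(3 * 4 * 5).reshape((3, 4, 5))

# Advanced indices in dims 0 and 2, split by a slice: NumPy places the
# advanced dims first, giving shape (2, 4)
direct = x[[1, 2], :, [3, 4]]

# The rewrite emulates this by transposing the advanced dims to the front
# (transposition = advanced positions + remaining positions), making the
# indices consecutive
moved = x.transpose(0, 2, 1)[[1, 2], [3, 4], :]

np.testing.assert_array_equal(direct, moved)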
pytensor/tensor/subtensor.py

Lines changed: 20 additions & 5 deletions
@@ -1,5 +1,6 @@
 import logging
 import sys
+import warnings
 from collections.abc import Callable, Iterable
 from itertools import chain, groupby
 from textwrap import dedent
@@ -571,7 +572,7 @@ def group_indices(indices):
     return idx_groups
 
 
-def _non_contiguous_adv_indexing(indices) -> bool:
+def _non_consecutive_adv_indexing(indices) -> bool:
     """Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
     idx_groups = group_indices(indices)
     # This means that there are at least two groups of advanced indexing separated by basic indexing
@@ -602,7 +603,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
     remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
     idx_groups = group_indices(indices)
 
-    if _non_contiguous_adv_indexing(indices):
+    if _non_consecutive_adv_indexing(indices):
         # In this case NumPy places the advanced index groups in the front of the array
         # https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
         idx_groups = sorted(idx_groups, key=lambda x: x[0])
@@ -2782,6 +2783,13 @@ def grad(self, inputs, grads):
 
     @staticmethod
     def non_contiguous_adv_indexing(node: Apply) -> bool:
+        warnings.warn(
+            "Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
+        )
+        return AdvancedSubtensor.non_consecutive_adv_indexing(node)
+
+    @staticmethod
+    def non_consecutive_adv_indexing(node: Apply) -> bool:
         """
         Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
 
@@ -2803,7 +2811,7 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
         True if the advanced indexing is non-contiguous, False otherwise.
         """
         _, *idxs = node.inputs
-        return _non_contiguous_adv_indexing(idxs)
+        return _non_consecutive_adv_indexing(idxs)
 
 
 advanced_subtensor = AdvancedSubtensor()
@@ -2821,7 +2829,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
         if isinstance(batch_idx, TensorVariable)
     )
 
-    if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)):
+    if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)):
         # Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing
         # which would put the indexed results to the left of the batch dimensions!
         # TODO: Not all cases must be handled by Blockwise, but the logic is complex
@@ -2940,6 +2948,13 @@ def grad(self, inpt, output_gradients):
 
     @staticmethod
     def non_contiguous_adv_indexing(node: Apply) -> bool:
+        warnings.warn(
+            "Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
+        )
+        return AdvancedIncSubtensor.non_consecutive_adv_indexing(node)
+
+    @staticmethod
+    def non_consecutive_adv_indexing(node: Apply) -> bool:
         """
         Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
 
@@ -2961,7 +2976,7 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
         True if the advanced indexing is non-contiguous, False otherwise.
         """
         _, _, *idxs = node.inputs
-        return _non_contiguous_adv_indexing(idxs)
+        return _non_consecutive_adv_indexing(idxs)
 
 
 advanced_inc_subtensor = AdvancedIncSubtensor()
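What `_non_consecutive_adv_indexing` distinguishes, sketched with NumPy result shapes (illustrative values, following the NumPy docs linked in the diff):

import numpy as np

x = np.arange(3 * 4 * 5).reshape((3, 4, 5))

# Consecutive advanced indices: the advanced dims stay in place
assert x[:, [1, 2], [3, 4]].shape == (3, 2)

# Non-consecutive (interrupted by basic indexing): NumPy moves the
# advanced group to the front of the result
assert x[[1, 2], :, [3, 4]].shape == (2, 4)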

tests/link/numba/test_subtensor.py

Lines changed: 50 additions & 16 deletions
@@ -81,11 +81,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
             (np.array([True, False, False])),
             False,
         ),
-        (
-            pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-            ([1, 2], [2, 3]),
-            False,
-        ),
         # Single multidimensional indexing (supported after specialization rewrites)
         (
             as_tensor(np.arange(3 * 3).reshape((3, 3))),
@@ -117,6 +112,12 @@ def test_AdvancedSubtensor1_out_of_bounds():
             (slice(2, None), np.eye(3).astype(bool)),
             False,
         ),
+        # Multiple vector indexing (supported by our dispatcher)
+        (
+            pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([1, 2], [2, 3]),
+            False,
+        ),
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             (slice(None), [1, 2], [3, 4]),
@@ -127,18 +128,35 @@ def test_AdvancedSubtensor1_out_of_bounds():
             ([1, 2], [3, 4], [5, 6]),
             False,
         ),
-        # Non-contiguous vector indexing, only supported in obj mode
+        # Non-consecutive vector indexing, supported by our dispatcher after rewriting
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             ([1, 2], slice(None), [3, 4]),
-            True,
+            False,
+        ),
+        # Multiple multidimensional integer indexing (supported by our dispatcher)
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([[1, 2], [2, 1]], [[0, 0], [0, 0]]),
+            False,
+        ),
+        (
+            as_tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))),
+            (slice(None), [[1, 2], [2, 1]], slice(None), [[0, 0], [0, 0]]),
+            False,
         ),
-        # >1d vector indexing, only supported in obj mode
+        # Multiple multidimensional indexing with broadcasting, only supported in obj mode
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             ([[1, 2], [2, 1]], [0, 0]),
             True,
         ),
+        # Multiple multidimensional integer indexing mixed with basic indexing, only supported in obj mode
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
+            True,
+        ),
     ],
 )
 @pytest.mark.filterwarnings("error")  # Raise if we did not expect objmode to be needed
@@ -297,15 +315,15 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             -np.arange(4 * 5).reshape(4, 5),
-            (0, [1, 2, 2, 3]),  # Broadcasted vector index
+            (0, [1, 2, 2, 3]),  # Broadcasted vector index with repeated values
             True,
             False,
             True,
         ),
         (
             np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             np.array([-99]),  # Broadcasted value
-            (0, [1, 2, 2, 3]),  # Broadcasted vector index
+            (0, [1, 2, 2, 3]),  # Broadcasted vector index with repeated values
             True,
             False,
             True,
@@ -380,7 +398,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             rng.poisson(size=(2, 4)),
-            ([1, 2], slice(None), [3, 4]),  # Non-contiguous vector indices
+            ([1, 2], slice(None), [3, 4]),  # Non-consecutive vector indices
             False,
             True,
             True,
@@ -400,15 +418,23 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(5),
             rng.poisson(size=(2, 2)),
-            ([[1, 2], [2, 3]]),  # matrix indices
+            ([[1, 2], [2, 3]]),  # matrix index
+            False,
+            False,
+            False,
+        ),
+        (
+            np.arange(3 * 5).reshape((3, 5)),
+            rng.poisson(size=(2, 2, 2)),
+            (slice(1, 3), [[1, 2], [2, 3]]),  # matrix index, mixed with basic index
+            False,
+            False,
             False,
-            False,  # Gets converted to AdvancedIncSubtensor1
-            True,  # This is actually supported with the default `ignore_duplicates=False`
         ),
         (
             np.arange(3 * 5).reshape((3, 5)),
-            rng.poisson(size=(1, 2, 2)),
-            (slice(1, 3), [[1, 2], [2, 3]]),  # matrix indices, mixed with basic index
+            rng.poisson(size=(1, 2, 2)),  # Same as before, but Y broadcasts
+            (slice(1, 3), [[1, 2], [2, 3]]),
             False,
             True,
             True,
@@ -421,6 +447,14 @@ def test_AdvancedIncSubtensor1(x, y, indices):
             False,
             False,
         ),
+        (
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            rng.poisson(size=(3, 2, 2)),
+            (slice(None), [[1, 2], [2, 1]], [[2, 3], [0, 0]]),  # 2 matrix indices
+            False,
+            False,
+            False,
+        ),
     ],
 )
 @pytest.mark.parametrize("inplace", (False, True))
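A hedged sketch mirroring the newly supported `(slice(1, 3), [[1, 2], [2, 3]])` update case above; the Numba mode string is an assumption:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.tensor3("y")
idx = np.array([[1, 2], [2, 3]])

# Consecutive matrix-index update; neither the index nor y broadcasts,
# so the rewrite can ravel the index and reshape y instead of object mode
out = pt.inc_subtensor(x[1:3, idx], y)
fn = pytensor.function([x, y], out, mode="NUMBA")

fn(np.zeros((3, 5)), np.ones((2, 2, 2)))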
