pytensor/link/numba/dispatch/subtensor.py (2 changes: 1 addition & 1 deletion)
@@ -150,7 +150,7 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
for adv_idx in adv_idxs
)
# Must be consecutive
and not op.non_contiguous_adv_indexing(node)
and not op.non_consecutive_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted
and (
y is None
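For context, a minimal NumPy sketch of the rule this check guards against: when advanced indices are separated by basic indexing, NumPy moves the advanced result dimensions to the front of the output, which is why the specialized set/inc dispatch only applies to the consecutive case.

import numpy as np

x = np.zeros((4, 5, 6))
idx = np.array([0, 1])

print(x[:, idx, idx].shape)  # (4, 2): consecutive advanced indices stay in place
print(x[idx, :, idx].shape)  # (2, 5): non-consecutive, the (2,) dim moves to the front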
pytensor/tensor/rewriting/subtensor.py (132 changes: 104 additions & 28 deletions)
@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
return [copy_stack_trace(node.outputs[0], new_out)]


@node_rewriter(tracks=[AdvancedSubtensor])
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
"""Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba

x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
"""Convert multidimensional integer indexing into equivalent consecutive vector integer index,
supported by Numba or by our specialized dispatchers

x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))

NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices

x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes

It also handles multiple integer indices, but only if they don't broadcast

x[eye(3), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes

Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast

x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape((-1, *y.shape[1:])))

"""
x, *idxs = node.inputs
op = node.op
non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)

if is_inc_subtensor:
x, y, *idxs = node.inputs
# Inc/SetSubtensor is harder to reason about due to y
# We get out if it's broadcasting or if the advanced indices are non-consecutive
if non_consecutive_adv_indexing or (
y.type.broadcastable != x[tuple(idxs)].type.broadcastable
):
return None

else:
x, *idxs = node.inputs

if any(
(
@@ -2049,50 +2072,103 @@ def ravel_multidimensional_int_idx(fgraph, node):
)
for idx in idxs
):
# Get out if there are any other advanced indexes or np.newaxis
# Get out if there are any other advanced indices or np.newaxis
return None

int_idxs = [
int_idxs_and_pos = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
]

if len(int_idxs) != 1:
# Get out if there are no or multiple integer idxs
if not int_idxs_and_pos:
return None

[(int_idx_pos, int_idx)] = int_idxs
if int_idx.type.ndim < 2:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
int_idxs_pos, int_idxs = zip(
*int_idxs_and_pos, strict=False
) # strict=False: the zipped pairs are equal-length by construction

first_int_idx_pos = int_idxs_pos[0]
first_int_idx = int_idxs[0]
first_int_idx_bcast = first_int_idx.type.broadcastable

if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
# We don't have a view-only broadcasting operation
# Explicitly broadcasting the indices can incur a memory / copy overhead
return None

raveled_int_idx = int_idx.ravel()
new_idxs = list(idxs)
new_idxs[int_idx_pos] = raveled_int_idx
raveled_subtensor = x[tuple(new_idxs)]

# Reshape into correct shape
# Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
# must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
raveled_shape = raveled_subtensor.shape
unraveled_shape = (
*raveled_shape[:int_idx_pos],
*int_idx.shape,
*raveled_shape[int_idx_pos + 1 :],
)
new_out = raveled_subtensor.reshape(unraveled_shape)
int_idxs_ndim = len(first_int_idx_bcast)
if (
int_idxs_ndim == 0
): # This should be a basic indexing operation, rewrite elsewhere
return None

int_idxs_need_raveling = int_idxs_ndim > 1
if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
# Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
return None

# Reorder non-consecutive indices
if non_consecutive_adv_indexing:
assert not is_inc_subtensor # Sanity check that we got out if this was the case
# This case works as if all the advanced indices were on the front
transposition = list(int_idxs_pos) + [
i for i in range(len(idxs)) if i not in int_idxs_pos
]
idxs = tuple(idxs[a] for a in transposition)
x = x.transpose(transposition)
first_int_idx_pos = 0
del int_idxs_pos # Make sure they are not wrongly used

# Ravel multidimensional indices
if int_idxs_need_raveling:
idxs = list(idxs)
for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
idxs[idx_pos] = int_idx.ravel()

# Index with reordered and/or raveled indices
new_subtensor = x[tuple(idxs)]

if is_inc_subtensor:
y_shape = tuple(y.shape)
y_raveled_shape = (
*y_shape[:first_int_idx_pos],
-1,
*y_shape[first_int_idx_pos + int_idxs_ndim :],
)
y_raveled = y.reshape(y_raveled_shape)

new_out = inc_subtensor(
new_subtensor,
y_raveled,
set_instead_of_inc=op.set_instead_of_inc,
ignore_duplicates=op.ignore_duplicates,
inplace=op.inplace,
)

else:
# Unravel advanced indexing dimensions
raveled_shape = tuple(new_subtensor.shape)
unraveled_shape = (
*raveled_shape[:first_int_idx_pos],
*first_int_idx.shape,
*raveled_shape[first_int_idx_pos + 1 :],
)
new_out = new_subtensor.reshape(unraveled_shape)

return [copy_stack_trace(node.outputs[0], new_out)]


optdb["specialize"].register(
ravel_multidimensional_bool_idx.__name__,
ravel_multidimensional_bool_idx,
"numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)

optdb["specialize"].register(
ravel_multidimensional_int_idx.__name__,
ravel_multidimensional_int_idx,
"numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)
pytensor/tensor/subtensor.py (64 changes: 47 additions & 17 deletions)
@@ -1,5 +1,6 @@
import logging
import sys
import warnings
from collections.abc import Callable, Iterable
from itertools import chain, groupby
from textwrap import dedent
@@ -59,6 +60,7 @@
zscalar,
)
from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
@@ -527,11 +529,20 @@ def basic_shape(shape, indices):
if isinstance(idx, slice):
res_shape += (slice_len(idx, n),)
elif isinstance(getattr(idx, "type", None), SliceType):
if idx.owner:
idx_inputs = idx.owner.inputs
if idx.owner is None:
if not isinstance(idx, Constant):
# This is an input slice, we can't reason symbolically on it.
# We don't even know if we will get None entries or integers
res_shape += (None,)
continue
else:
sl: slice = idx.data
slice_inputs = (sl.start, sl.stop, sl.step)
elif isinstance(idx.owner.op, MakeSlice):
slice_inputs = idx.owner.inputs
else:
idx_inputs = (None,)
res_shape += (slice_len(slice(*idx_inputs), n),)
raise ValueError(f"Unexpected Slice producing Op {idx.owner.op}")
res_shape += (slice_len(slice(*slice_inputs), n),)
elif idx is None:
res_shape += (ps.ScalarConstant(ps.int64, 1),)
elif isinstance(getattr(idx, "type", None), NoneTypeT):
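As a plain-Python sketch of the quantity slice_len computes above for a fully concrete slice over an axis of known length n (concrete_slice_len is a hypothetical helper, not part of the codebase):

def concrete_slice_len(sl: slice, n: int) -> int:
    # Normalize start/stop/step against the axis length, then count elements
    start, stop, step = sl.indices(n)
    return len(range(start, stop, step))

assert concrete_slice_len(slice(2, None), 5) == 3  # x[2:] over an axis of length 5
assert concrete_slice_len(slice(None, None, 2), 5) == 3  # x[::2]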
@@ -570,8 +581,8 @@ def group_indices(indices):
return idx_groups


def _non_contiguous_adv_indexing(indices) -> bool:
"""Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
def _non_consecutive_adv_indexing(indices) -> bool:
"""Check if the advanced indexing is non-consecutive (i.e., split by basic indexing)."""
idx_groups = group_indices(indices)
# This means that there are at least two groups of advanced indexing separated by basic indexing
return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])
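An equivalent standalone formulation of this rule, assuming group_indices yields maximal runs whose first element flags basic indexing: the advanced indexing is non-consecutive exactly when more than one advanced run exists.

from itertools import groupby

def non_consecutive(is_basic_flags: list[bool]) -> bool:
    # Collapse the per-index flags into maximal runs of basic/advanced indexing
    runs = [key for key, _ in groupby(is_basic_flags)]
    return runs.count(False) > 1  # more than one advanced run

assert non_consecutive([False, True, False])  # x[adv, :, adv]
assert not non_consecutive([True, False, False])  # x[:, adv, adv]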
@@ -601,7 +612,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
idx_groups = group_indices(indices)

if _non_contiguous_adv_indexing(indices):
if _non_consecutive_adv_indexing(indices):
# In this case NumPy places the advanced index groups in the front of the array
# https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
idx_groups = sorted(idx_groups, key=lambda x: x[0])
@@ -2728,6 +2739,11 @@ def is_bool_index(idx):
res_shape = list(
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
)
for i, res_dim_length in enumerate(res_shape):
if res_dim_length is None:
# This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice)
# We must compute the Op to find its shape
res_shape[i] = Shape_i(i)(node.out)

adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
@@ -2781,10 +2797,17 @@ def grad(self, inputs, grads):

@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedSubtensor.non_consecutive_adv_indexing(node)

@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).

This function checks if the advanced indexing is non-contiguous,
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position.

@@ -2799,10 +2822,10 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
True if the advanced indexing is non-consecutive, False otherwise.
"""
_, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs)
return _non_consecutive_adv_indexing(idxs)


advanced_subtensor = AdvancedSubtensor()
@@ -2820,7 +2843,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
if isinstance(batch_idx, TensorVariable)
)

if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)):
if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)):
# Fallback to Blockwise if idxs are batched or if we have non-consecutive advanced indexing
# which would put the indexed results to the left of the batch dimensions!
# TODO: Not all cases must be handled by Blockwise, but the logic is complex
@@ -2829,7 +2852,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
# TODO: Implement these internally, so Blockwise is always a safe fallback
if any(not isinstance(idx, TensorVariable) for idx in idxs):
raise NotImplementedError(
"Vectorized AdvancedSubtensor with batched indexes or non-contiguous advanced indexing "
"Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing "
"and slices or newaxis is currently not supported."
)
else:
@@ -2939,10 +2962,17 @@ def grad(self, inpt, output_gradients):

@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedIncSubtensor.non_consecutive_adv_indexing(node)

@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).

This function checks if the advanced indexing is non-contiguous,
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position.

@@ -2957,10 +2987,10 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
True if the advanced indexing is non-consecutive, False otherwise.
"""
_, _, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs)
return _non_consecutive_adv_indexing(idxs)


advanced_inc_subtensor = AdvancedIncSubtensor()
scripts/mypy-failing.txt (1 change: 0 additions & 1 deletion)
@@ -11,7 +11,6 @@ pytensor/link/numba/dispatch/scan.py
pytensor/printing.py
pytensor/raise_op.py
pytensor/sparse/basic.py
pytensor/sparse/type.py
pytensor/tensor/basic.py
pytensor/tensor/blas_c.py
pytensor/tensor/blas_headers.py
scripts/run_mypy.py (8 changes: 7 additions & 1 deletion)
@@ -142,7 +142,13 @@ def check_no_unexpected_results(mypy_lines: Iterable[str]):
print(*missing, sep="\n")
sys.exit(1)
cp = subprocess.run(
["mypy", "--show-error-codes", "pytensor"],
[
"mypy",
"--show-error-codes",
"--disable-error-code",
"annotation-unchecked",
"pytensor",
],
capture_output=True,
)
output = cp.stdout.decode()
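The new subprocess invocation corresponds to the command line below; per mypy's documented --disable-error-code option, it suppresses the annotation-unchecked notes mypy emits for annotated assignments inside untyped function bodies:

mypy --show-error-codes --disable-error-code annotation-unchecked pytensor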