diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index 334382d132..476477cc80 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -2,7 +2,7 @@
 import sys
 import warnings
 from collections.abc import Callable, Iterable, Sequence
-from itertools import chain, groupby
+from itertools import chain, groupby, zip_longest
 from typing import cast, overload
 
 import numpy as np
@@ -39,7 +39,7 @@
 from pytensor.tensor.blockwise import vectorize_node_fallback
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
-from pytensor.tensor.math import clip
+from pytensor.tensor.math import add, clip
 from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
@@ -63,6 +63,7 @@
 from pytensor.tensor.type_other import (
     MakeSlice,
     NoneConst,
+    NoneSliceConst,
     NoneTypeT,
     SliceConstant,
     SliceType,
@@ -844,6 +845,40 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
     return ps.as_scalar(a)
 
 
+def slice_static_length(slc, dim_length):
+    """Return the static length of ``slc`` applied to a dimension, if known.
+
+    Parameters
+    ----------
+    slc
+        Slice whose ``start``, ``stop`` and ``step`` are each ``None`` or a
+        value accepted by `get_scalar_constant_value`.
+    dim_length
+        Static length of the sliced dimension, or ``None`` when unknown.
+
+    Returns
+    -------
+    int or None
+        The number of selected entries, or ``None`` when `dim_length` is
+        unknown or any slice entry is not a constant.
+    """
+    if dim_length is None:
+        # TODO: Some cases must be zero by definition, we could handle those
+        return None
+
+    entries = [None, None, None]
+    for i, entry in enumerate((slc.start, slc.stop, slc.step)):
+        if entry is None:
+            continue
+
+        try:
+            entries[i] = get_scalar_constant_value(entry)
+        except NotScalarConstantError:
+            return None
+
+    return len(range(*slice(*entries).indices(dim_length)))
+
+
 class Subtensor(COp):
     """Basic NumPy indexing operator."""
 
@@ -886,50 +921,15 @@ def make_node(self, x, *inputs):
                 )
 
         padded = [
-            *get_idx_list((None, *inputs), self.idx_list),
+            *indices_from_subtensor(inputs, self.idx_list),
             *[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
         ]
 
-        out_shape = []
-
-        def extract_const(value):
-            if value is None:
-                return value, True
-            try:
-                value = get_scalar_constant_value(value)
-                return value, True
-            except NotScalarConstantError:
-                return value, False
-
-        for the_slice, length in zip(padded, x.type.shape, strict=True):
-            if not isinstance(the_slice, slice):
-                continue
-
-            if length is None:
-                out_shape.append(None)
-                continue
-
-            start = the_slice.start
-            stop = the_slice.stop
-            step = the_slice.step
-
-            is_slice_const = True
-
-            start, is_const = extract_const(start)
-            is_slice_const = is_slice_const and is_const
-
-            stop, is_const = extract_const(stop)
-            is_slice_const = is_slice_const and is_const
-
-            step, is_const = extract_const(step)
-            is_slice_const = is_slice_const and is_const
-
-            if not is_slice_const:
-                out_shape.append(None)
-                continue
-
-            slice_length = len(range(*slice(start, stop, step).indices(length)))
-            out_shape.append(slice_length)
+        out_shape = [
+            slice_static_length(slc, length)
+            for slc, length in zip(padded, x.type.shape, strict=True)
+            if isinstance(slc, slice)
+        ]
 
         return Apply(
             self,
@@ -2826,36 +2826,129 @@ class AdvancedSubtensor(Op):
 
     __props__ = ()
 
-    def make_node(self, x, *index):
+    def make_node(self, x, *indices):
+        """Build the `Apply` node and infer the static shape of the output.
+
+        Boolean indices are expanded with ``nonzero`` for shape inference only;
+        the node itself keeps the original ``indices`` as inputs.
+
+        Raises
+        ------
+        NotImplementedError
+            If a scalar boolean index is given.
+        IndexError
+            If static shapes show a boolean index mismatch, too many indices,
+            or integer indices that cannot be broadcast together.
+        """
         x = as_tensor_variable(x)
-        index = tuple(map(as_index_variable, index))
+        indices = tuple(map(as_index_variable, indices))
+
+        explicit_indices = []
+        new_axes = []
+        for idx in indices:
+            if isinstance(idx.type, TensorType) and idx.dtype == "bool":
+                if idx.type.ndim == 0:
+                    raise NotImplementedError(
+                        "Indexing with scalar booleans not supported"
+                    )
 
-        # We create a fake symbolic shape tuple and identify the broadcast
-        # dimensions from the shape result of this entire subtensor operation.
-        with config.change_flags(compute_test_value="off"):
-            fake_shape = tuple(
-                tensor(dtype="int64", shape=()) if s != 1 else 1 for s in x.type.shape
-            )
+                # Check static shape aligned
+                axis = len(explicit_indices) - len(new_axes)
+                indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
+                for j, (indexed_length, indexer_length) in enumerate(
+                    zip(indexed_shape, idx.type.shape)
+                ):
+                    if (
+                        indexed_length is not None
+                        and indexer_length is not None
+                        and indexed_length != indexer_length
+                    ):
+                        raise IndexError(
+                            f"boolean index did not match indexed tensor along axis {axis + j}; "
+                            f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
+                        )
+                # Convert boolean indices to integer with nonzero, to reason about static shape next
+                if isinstance(idx, Constant):
+                    nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
+                else:
+                    # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
+                    # and seeing that other integer indices cannot possibly match it
+                    nonzero_indices = idx.nonzero()
+                explicit_indices.extend(nonzero_indices)
+            else:
+                if isinstance(idx.type, NoneTypeT):
+                    new_axes.append(len(explicit_indices))
+                explicit_indices.append(idx)
 
-            fake_index = tuple(
-                chain.from_iterable(
-                    pytensor.tensor.basic.nonzero(idx)
-                    if getattr(idx, "ndim", 0) > 0
-                    and getattr(idx, "dtype", None) == "bool"
-                    else (idx,)
-                    for idx in index
-                )
+        if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
+            raise IndexError(
+                f"too many indices for tensor: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
             )
 
-        out_shape = tuple(
-            i.value if isinstance(i, Constant) else None
-            for i in indexed_result_shape(fake_shape, fake_index)
-        )
+        # Perform basic and advanced indexing shape inference separately
+        basic_group_shape = []
+        advanced_indices = []
+        adv_group_axis = None
+        last_adv_group_axis = None
+        # Shape of x once the new axes are in place, so that each explicit index
+        # is paired with the length of the dimension it acts on. `new_axes` holds
+        # ascending positions in `explicit_indices`, so sequential inserts put
+        # each length-1 axis exactly where its `None` index sits.
+        expanded_x_shape = list(x.type.shape)
+        for new_axis in new_axes:
+            expanded_x_shape.insert(new_axis, 1)
+        for i, (idx, dim_length) in enumerate(
+            zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
+        ):
+            if isinstance(idx.type, NoneTypeT):
+                basic_group_shape.append(1)  # New-axis
+            elif isinstance(idx.type, SliceType):
+                if isinstance(idx, Constant):
+                    basic_group_shape.append(slice_static_length(idx.data, dim_length))
+                elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
+                    basic_group_shape.append(
+                        slice_static_length(slice(*idx.owner.inputs), dim_length)
+                    )
+                else:
+                    # Symbolic root slice (owner is None), or slice operation we don't understand
+                    basic_group_shape.append(None)
+            else:  # TensorType
+                # Keep track of advanced group axis
+                if adv_group_axis is None:
+                    # First time we see an advanced index
+                    adv_group_axis, last_adv_group_axis = i, i
+                elif last_adv_group_axis == (i - 1):
+                    # Another advanced indexing aligned with the first group
+                    last_adv_group_axis = i
+                else:
+                    # Non-consecutive advanced index, all advanced index views get moved to the front
+                    adv_group_axis = 0
+                advanced_indices.append(idx)
+
+        if advanced_indices:
+            try:
+                # Use variadic add to infer static shape of advanced integer indices
+                advanced_group_static_shape = add(*advanced_indices).type.shape
+            except ValueError as err:
+                # It fails when static shapes are inconsistent
+                static_shapes = [idx.type.shape for idx in advanced_indices]
+                raise IndexError(
+                    f"shape mismatch: indexing tensors could not be broadcast together with shapes {static_shapes}"
+                ) from err
+            # Combine advanced and basic views
+            indexed_shape = [
+                *basic_group_shape[:adv_group_axis],
+                *advanced_group_static_shape,
+                *basic_group_shape[adv_group_axis:],
+            ]
+        else:
+            # This could have been a basic subtensor!
+            indexed_shape = basic_group_shape
 
         return Apply(
             self,
-            (x, *index),
-            [tensor(dtype=x.type.dtype, shape=out_shape)],
+            [x, *indices],
+            [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
         )
 
     def R_op(self, inputs, eval_points):
diff --git a/pytensor/tensor/type_other.py b/pytensor/tensor/type_other.py
index a9e559504f..1454b18cde 100644
--- a/pytensor/tensor/type_other.py
+++ b/pytensor/tensor/type_other.py
@@ -114,6 +114,10 @@ def as_symbolic_slice(x, **kwargs):
     return SliceConstant(slicetype, x)
 
 
+# Constant ``slice(None)``, usable as a filler for dimensions left unindexed
+NoneSliceConst = Constant(slicetype, slice(None), name="slice(None)")
+
+
 class NoneTypeT(Generic):
     """
     Inherit from Generic to have c code working.
@@ -137,4 +141,4 @@ def as_symbolic_None(x, **kwargs):
     return NoneConst
 
 
-__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"]
+__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst", "NoneSliceConst"]
diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py
index 56fe76da0c..474d08c49d 100644
--- a/pytensor/tensor/variable.py
+++ b/pytensor/tensor/variable.py
@@ -506,7 +506,9 @@ def includes_bool(args_el):
 
         # Check if the number of dimensions isn't too large.
         if self.ndim < index_dim_count:
-            raise IndexError("too many indices for array")
+            raise IndexError(
+                f"too many indices for tensor: tensor is {self.ndim}-dimensional, but {index_dim_count} were indexed"
+            )
 
         # Convert an Ellipsis if provided into an appropriate number of
         # slice(None).
diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py
index 20ca3a2420..ead94371d3 100644
--- a/tests/tensor/test_subtensor.py
+++ b/tests/tensor/test_subtensor.py
@@ -1,4 +1,5 @@
 import logging
+import re
 import sys
 from io import StringIO
 
@@ -1847,6 +1848,97 @@ def setup_method(self):
         self.ix2 = lmatrix()
         self.ixr = lrow()
 
+    def test_static_shape(self):
+        x = tensor("x", shape=(None, None))
+        y = tensor("y", shape=(4, 5, 6))
+        idx1 = tensor("idx1", shape=(10,), dtype=int)
+        idx2 = tensor("idx2", shape=(3, None), dtype=int)
+
+        assert x[idx1].type.shape == (10, None)
+        assert x[:, idx1].type.shape == (None, 10)
+        assert x[idx2, :5].type.shape == (3, None, None)
+        assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5)
+        assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)
+        assert x[idx1, idx2].type.shape == (3, 10)
+        assert x[idx2, idx1].type.shape == (3, 10)
+        assert x[None, idx1, idx2].type.shape == (1, 3, 10)
+        assert x[idx1, None, idx2].type.shape == (3, 10, 1)
+        assert x[idx1, idx2, None].type.shape == (3, 10, 1)
+
+        assert y[idx1, idx2, ::-1].type.shape == (3, 10, 6)
+        assert y[idx1, ::-1, idx2].type.shape == (3, 10, 5)
+        assert y[::-1, idx1, idx2].type.shape == (4, 3, 10)
+        assert y[::-1, idx1, None, idx2].type.shape == (3, 10, 4, 1)
+        # A new axis placed after a slice must not shift the static dim lengths
+        assert y[idx1, :, None, :].type.shape == (10, 5, 1, 6)
+
+        msg = re.escape(
+            "shape mismatch: indexing tensors could not be broadcast together with shapes [(10,), (9,)]"
+        )
+        with pytest.raises(IndexError, match=msg):
+            x[idx1, idx1[1:]]
+
+    def test_static_shape_boolean(self):
+        y = tensor("y", shape=(4, 5, 6))
+        idx1 = tensor("idx1", shape=(4,), dtype=int)
+        idx2 = tensor("idx2", shape=(3, None), dtype=int)
+        bool_idx1 = tensor("bool_idx1", shape=(4,), dtype=bool)
+        bool_idx2 = tensor(
+            "bool_idx2",
+            shape=(
+                None,
+                5,
+            ),
+            dtype=bool,
+        )
+
+        assert y[bool_idx1].type.shape == (None, 5, 6)
+        assert y[bool_idx1, :, None:-4:-1].type.shape == (None, 5, 3)
+        assert y[bool_idx1, idx2].type.shape == (3, None, 6)
+        assert y[bool_idx1, idx1, :].type.shape == (4, 6)
+        assert y[bool_idx1, :, idx1].type.shape == (4, 5)
+        assert y[bool_idx1, idx1, idx2].type.shape == (3, 4)
+        assert y[None, bool_idx1, None, idx2, None, idx1].type.shape == (3, 4, 1, 1, 1)
+
+        assert y[bool_idx2, :].type.shape == (None, 6)
+        assert y[bool_idx2, idx1].type.shape == (4,)
+        assert y[bool_idx2, idx2].type.shape == (3, None)
+
+        msg = re.escape(
+            "too many indices for tensor: tensor is 3-dimensional, but 4 were indexed"
+        )
+        with pytest.raises(IndexError, match=msg):
+            y[bool_idx2, bool_idx2]
+
+        # Case that could conceivably be detected as index error at definition time
+        bad_idx = ptb.concatenate([idx1, idx1])
+        assert y[bool_idx1, bad_idx].type.shape == (8, 6)
+
+    def test_static_shape_constant_boolean(self):
+        y = tensor("y", shape=(None, None, None))
+        idx1 = tensor("idx1", shape=(3,), dtype=int)
+        idx2 = tensor("idx2", shape=(4, None), dtype=int)
+
+        bool_idx1 = constant(np.array([True, False, True, True]), name="bool_idx1")
+        bool_idx2 = constant(
+            np.array([[True, False, True, True], [True, False, False, True]]),
+            name="bool_idx2",
+        )
+
+        assert y[bool_idx1].type.shape == (3, None, None)
+        assert y[bool_idx1, :, idx1].type.shape == (3, None)
+        assert y[bool_idx1, :, idx2].type.shape == (4, 3, None)
+
+        assert y[bool_idx2].type.shape == (5, None)
+        assert y[bool_idx1, idx2].type.shape == (4, 3, None)
+
+        bad_idx = ptb.concatenate([idx1, idx1])
+        msg = re.escape(
+            "shape mismatch: indexing tensors could not be broadcast together with shapes [(3,), (6,)]"
+        )
+        with pytest.raises(IndexError, match=msg):
+            y[bool_idx1, bad_idx]
+
     @pytest.mark.parametrize(
         "inplace",
         [