Commit 0fd160b

Implement static_shape inference for AdvancedSubtensor

1 parent f7cc0f0 commit 0fd160b

4 files changed (+222, -67 lines)

pytensor/tensor/subtensor.py (125 additions, 65 deletions)

@@ -2,7 +2,7 @@
 import sys
 import warnings
 from collections.abc import Callable, Iterable, Sequence
-from itertools import chain, groupby
+from itertools import chain, groupby, zip_longest
 from typing import cast, overload
 
 import numpy as np
@@ -39,7 +39,7 @@
 from pytensor.tensor.blockwise import vectorize_node_fallback
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
-from pytensor.tensor.math import clip
+from pytensor.tensor.math import add, clip
 from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
@@ -63,6 +63,7 @@
 from pytensor.tensor.type_other import (
     MakeSlice,
     NoneConst,
+    NoneSliceConst,
     NoneTypeT,
     SliceConstant,
     SliceType,
@@ -844,6 +845,24 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
     return ps.as_scalar(a)
 
 
+def slice_static_length(slc, dim_length):
+    if dim_length is None:
+        # TODO: Some cases must be zero by definition, we could handle those
+        return None
+
+    entries = [None, None, None]
+    for i, entry in enumerate((slc.start, slc.stop, slc.step)):
+        if entry is None:
+            continue
+
+        try:
+            entries[i] = get_scalar_constant_value(entry)
+        except NotScalarConstantError:
+            return None
+
+    return len(range(*slice(*entries).indices(dim_length)))
+
+
 class Subtensor(COp):
     """Basic NumPy indexing operator."""
 
@@ -886,50 +905,15 @@ def make_node(self, x, *inputs):
                 )
 
         padded = [
-            *get_idx_list((None, *inputs), self.idx_list),
+            *indices_from_subtensor(inputs, self.idx_list),
             *[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
         ]
 
-        out_shape = []
-
-        def extract_const(value):
-            if value is None:
-                return value, True
-            try:
-                value = get_scalar_constant_value(value)
-                return value, True
-            except NotScalarConstantError:
-                return value, False
-
-        for the_slice, length in zip(padded, x.type.shape, strict=True):
-            if not isinstance(the_slice, slice):
-                continue
-
-            if length is None:
-                out_shape.append(None)
-                continue
-
-            start = the_slice.start
-            stop = the_slice.stop
-            step = the_slice.step
-
-            is_slice_const = True
-
-            start, is_const = extract_const(start)
-            is_slice_const = is_slice_const and is_const
-
-            stop, is_const = extract_const(stop)
-            is_slice_const = is_slice_const and is_const
-
-            step, is_const = extract_const(step)
-            is_slice_const = is_slice_const and is_const
-
-            if not is_slice_const:
-                out_shape.append(None)
-                continue
-
-            slice_length = len(range(*slice(start, stop, step).indices(length)))
-            out_shape.append(slice_length)
+        out_shape = [
+            slice_static_length(slc, length)
+            for slc, length in zip(padded, x.type.shape, strict=True)
+            if isinstance(slc, slice)
+        ]
 
         return Apply(
             self,
@@ -2826,36 +2810,112 @@ class AdvancedSubtensor(Op):
 
     __props__ = ()
 
-    def make_node(self, x, *index):
+    def make_node(self, x, *indices):
         x = as_tensor_variable(x)
-        index = tuple(map(as_index_variable, index))
+        indices = tuple(map(as_index_variable, indices))
+
+        explicit_indices = []
+        new_axes = []
+        for idx in indices:
+            if isinstance(idx.type, TensorType) and idx.dtype == "bool":
+                if idx.type.ndim == 0:
+                    raise NotImplementedError(
+                        "Indexing with scalar booleans not supported"
+                    )
 
-        # We create a fake symbolic shape tuple and identify the broadcast
-        # dimensions from the shape result of this entire subtensor operation.
-        with config.change_flags(compute_test_value="off"):
-            fake_shape = tuple(
-                tensor(dtype="int64", shape=()) if s != 1 else 1 for s in x.type.shape
-            )
+                # Check static shapes are aligned
+                axis = len(explicit_indices) - len(new_axes)
+                indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
+                for j, (indexed_length, indexer_length) in enumerate(
+                    zip(indexed_shape, idx.type.shape)
+                ):
+                    if (
+                        indexed_length is not None
+                        and indexer_length is not None
+                        and indexed_length != indexer_length
+                    ):
+                        raise IndexError(
+                            f"boolean index did not match indexed tensor along axis {axis + j}; "
+                            f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
+                        )
+                # Convert boolean indices to integer with nonzero, to reason about static shape next
+                if isinstance(idx, Constant):
+                    nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
+                else:
+                    # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
+                    # and seeing that other integer indices cannot possibly match it
+                    nonzero_indices = idx.nonzero()
+                explicit_indices.extend(nonzero_indices)
+            else:
+                if isinstance(idx.type, NoneTypeT):
+                    new_axes.append(len(explicit_indices))
+                explicit_indices.append(idx)
 
-        fake_index = tuple(
-            chain.from_iterable(
-                pytensor.tensor.basic.nonzero(idx)
-                if getattr(idx, "ndim", 0) > 0
-                and getattr(idx, "dtype", None) == "bool"
-                else (idx,)
-                for idx in index
-            )
-        )
+        if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
+            raise IndexError(
+                f"too many indices for tensor: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
+            )
 
-        out_shape = tuple(
-            i.value if isinstance(i, Constant) else None
-            for i in indexed_result_shape(fake_shape, fake_index)
-        )
+        # Perform basic and advanced indexing shape inference separately
+        basic_group_shape = []
+        advanced_indices = []
+        adv_group_axis = None
+        last_adv_group_axis = None
+        expanded_x_shape = tuple(
+            np.insert(np.array(x.type.shape, dtype=object), new_axes, 1)
+        )
+        for i, (idx, dim_length) in enumerate(
+            zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
+        ):
+            if isinstance(idx.type, NoneTypeT):
+                basic_group_shape.append(1)  # New-axis
+            elif isinstance(idx.type, SliceType):
+                if isinstance(idx, Constant):
+                    basic_group_shape.append(slice_static_length(idx.data, dim_length))
+                elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
+                    basic_group_shape.append(
+                        slice_static_length(slice(*idx.owner.inputs), dim_length)
+                    )
+                else:
+                    # Symbolic root slice (owner is None), or slice operation we don't understand
+                    basic_group_shape.append(None)
+            else:  # TensorType
+                # Keep track of the advanced group axis
+                if adv_group_axis is None:
+                    # First time we see an advanced index
+                    adv_group_axis, last_adv_group_axis = i, i
+                elif last_adv_group_axis == (i - 1):
+                    # Another advanced index aligned with the first group
+                    last_adv_group_axis = i
+                else:
+                    # Non-consecutive advanced index, all advanced index dims get moved to the front
+                    adv_group_axis = 0
+                advanced_indices.append(idx)
+
+        if advanced_indices:
+            try:
+                # Use variadic add to infer the broadcasted static shape of the advanced integer indices
+                advanced_group_static_shape = add(*advanced_indices).type.shape
+            except ValueError:
+                # It fails when static shapes are inconsistent
+                static_shapes = [idx.type.shape for idx in advanced_indices]
+                raise IndexError(
+                    f"shape mismatch: indexing tensors could not be broadcast together with shapes {static_shapes}"
+                )
+            # Combine advanced and basic views
+            indexed_shape = [
+                *basic_group_shape[:adv_group_axis],
+                *advanced_group_static_shape,
+                *basic_group_shape[adv_group_axis:],
+            ]
+        else:
+            # This could have been a basic subtensor!
+            indexed_shape = basic_group_shape
 
         return Apply(
             self,
-            (x, *index),
-            [tensor(dtype=x.type.dtype, shape=out_shape)],
+            [x, *indices],
+            [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
        )
 
     def R_op(self, inputs, eval_points):
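
Aside: the new slice_static_length helper rests entirely on Python's built-in slice.indices, which normalizes negative and defaulted bounds against a known dimension length. A minimal standalone sketch of the same trick (the helper name and values below are illustrative, not part of the commit):

    def static_slice_length(start, stop, step, dim_length):
        # Normalize the slice against the known length, then count the
        # resulting range; this equals len(x[start:stop:step]) along that dim.
        return len(range(*slice(start, stop, step).indices(dim_length)))

    assert static_slice_length(None, None, -1, 6) == 6  # y[::-1] on a length-6 dim
    assert static_slice_length(None, -4, -1, 6) == 3    # y[None:-4:-1], as in the tests
    assert static_slice_length(None, 5, None, 7) == 5   # x[:5] on a length-7 dim

When any bound is a non-constant symbolic scalar, slice_static_length gives up and returns None (unknown length), which is why get_scalar_constant_value is wrapped in a try/except NotScalarConstantError.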

pytensor/tensor/type_other.py (4 additions, 1 deletion)

@@ -114,6 +114,9 @@ def as_symbolic_slice(x, **kwargs):
     return SliceConstant(slicetype, x)
 
 
+NoneSliceConst = Constant(slicetype, slice(None), name="slice(None)")
+
+
 class NoneTypeT(Generic):
     """
     Inherit from Generic to have c code working.
@@ -137,4 +140,4 @@ def as_symbolic_None(x, **kwargs):
     return NoneConst
 
 
-__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"]
+__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst", "NoneSliceConst"]
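
NoneSliceConst gives the shape-inference loop in subtensor.py a ready-made constant slice(None) to use as the zip_longest fill value for dimensions that were not explicitly indexed. That padding matches NumPy semantics, where unindexed trailing dimensions behave like full slices; a quick check (the array here is an arbitrary example, not from the commit):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    # Trailing dims that are not explicitly indexed act like slice(None)
    assert np.array_equal(x[[0, 1]], x[[0, 1], slice(None), slice(None)])
    assert x[[0, 1]].shape == (2, 3, 4)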

pytensor/tensor/variable.py (3 additions, 1 deletion)

@@ -506,7 +506,9 @@ def includes_bool(args_el):
 
         # Check if the number of dimensions isn't too large.
         if self.ndim < index_dim_count:
-            raise IndexError("too many indices for array")
+            raise IndexError(
+                f"too many indices for tensor: tensor is {self.ndim}-dimensional, but {index_dim_count} were indexed"
+            )
 
         # Convert an Ellipsis if provided into an appropriate number of
         # slice(None).
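
The reworded error deliberately mirrors NumPy's own message (with "tensor" in place of "array"), so the failure raised at graph-definition time reads like the one NumPy would raise at runtime. For comparison, on recent NumPy versions:

    import numpy as np

    try:
        np.zeros((2, 3))[0, 0, 0]
    except IndexError as exc:
        print(exc)
    # too many indices for array: array is 2-dimensional, but 3 were indexed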

tests/tensor/test_subtensor.py (90 additions, 0 deletions)

@@ -1,4 +1,5 @@
 import logging
+import re
 import sys
 from io import StringIO
 
@@ -1847,6 +1848,95 @@ def setup_method(self):
         self.ix2 = lmatrix()
         self.ixr = lrow()
 
+    def test_static_shape(self):
+        x = tensor("x", shape=(None, None))
+        y = tensor("y", shape=(4, 5, 6))
+        idx1 = tensor("idx1", shape=(10,), dtype=int)
+        idx2 = tensor("idx2", shape=(3, None), dtype=int)
+
+        assert x[idx1].type.shape == (10, None)
+        assert x[:, idx1].type.shape == (None, 10)
+        assert x[idx2, :5].type.shape == (3, None, None)
+        assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5)
+        assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)
+        assert x[idx1, idx2].type.shape == (3, 10)
+        assert x[idx2, idx1].type.shape == (3, 10)
+        assert x[None, idx1, idx2].type.shape == (1, 3, 10)
+        assert x[idx1, None, idx2].type.shape == (3, 10, 1)
+        assert x[idx1, idx2, None].type.shape == (3, 10, 1)
+
+        assert y[idx1, idx2, ::-1].type.shape == (3, 10, 6)
+        assert y[idx1, ::-1, idx2].type.shape == (3, 10, 5)
+        assert y[::-1, idx1, idx2].type.shape == (4, 3, 10)
+        assert y[::-1, idx1, None, idx2].type.shape == (3, 10, 4, 1)
+
+        msg = re.escape(
+            "shape mismatch: indexing tensors could not be broadcast together with shapes [(10,), (9,)]"
+        )
+        with pytest.raises(IndexError, match=msg):
+            x[idx1, idx1[1:]]
+
+    def test_static_shape_boolean(self):
+        y = tensor("y", shape=(4, 5, 6))
+        idx1 = tensor("idx1", shape=(4,), dtype=int)
+        idx2 = tensor("idx2", shape=(3, None), dtype=int)
+        bool_idx1 = tensor("bool_idx1", shape=(4,), dtype=bool)
+        bool_idx2 = tensor(
+            "bool_idx2",
+            shape=(
+                None,
+                5,
+            ),
+            dtype=bool,
+        )
+
+        assert y[bool_idx1].type.shape == (None, 5, 6)
+        assert y[bool_idx1, :, None:-4:-1].type.shape == (None, 5, 3)
+        assert y[bool_idx1, idx2].type.shape == (3, None, 6)
+        assert y[bool_idx1, idx1, :].type.shape == (4, 6)
+        assert y[bool_idx1, :, idx1].type.shape == (4, 5)
+        assert y[bool_idx1, idx1, idx2].type.shape == (3, 4)
+        assert y[None, bool_idx1, None, idx2, None, idx1].type.shape == (3, 4, 1, 1, 1)
+
+        assert y[bool_idx2, :].type.shape == (None, 6)
+        assert y[bool_idx2, idx1].type.shape == (4,)
+        assert y[bool_idx2, idx2].type.shape == (3, None)
+
+        msg = re.escape(
+            "too many indices for tensor: tensor is 3-dimensional, but 4 were indexed"
+        )
+        with pytest.raises(IndexError, match=msg):
+            y[bool_idx2, bool_idx2]
+
+        # Case that could conceivably be detected as index error at definition time
+        bad_idx = ptb.concatenate([idx1, idx1])
+        assert y[bool_idx1, bad_idx].type.shape == (8, 6)
+
+    def test_static_shape_constant_boolean(self):
+        y = tensor("y", shape=(None, None, None))
+        idx1 = tensor("idx1", shape=(3,), dtype=int)
+        idx2 = tensor("idx2", shape=(4, None), dtype=int)
+
+        bool_idx1 = constant(np.array([True, False, True, True]), name="bool_idx1")
+        bool_idx2 = constant(
+            np.array([[True, False, True, True], [True, False, False, True]]),
+            name="bool_idx2",
+        )
+
+        assert y[bool_idx1].type.shape == (3, None, None)
+        assert y[bool_idx1, :, idx1].type.shape == (3, None)
+        assert y[bool_idx1, :, idx2].type.shape == (4, 3, None)
+
+        assert y[bool_idx2].type.shape == (5, None)
+        assert y[bool_idx1, idx2].type.shape == (4, 3, None)
+
+        bad_idx = ptb.concatenate([idx1, idx1])
+        msg = re.escape(
+            "shape mismatch: indexing tensors could not be broadcast together with shapes [(3,), (6,)]"
+        )
+        with pytest.raises(IndexError, match=msg):
+            y[bool_idx1, bad_idx]
+
     @pytest.mark.parametrize(
         "inplace",
         [

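Each expected static shape above can be cross-checked against concrete NumPy indexing. The trickiest case is y[::-1, idx1, None, idx2]: the two integer indices are separated by a basic index, so their broadcast dimensions move to the front of the result. A quick verification with placeholder index values (the zeros are arbitrary, only the shapes matter):

    import numpy as np

    y = np.zeros((4, 5, 6))
    idx1 = np.zeros(10, dtype=int)       # same shape as the symbolic idx1: (10,)
    idx2 = np.zeros((3, 10), dtype=int)  # a concrete instance of idx2: (3, None)

    # Non-consecutive advanced indices: the broadcast group (3, 10) moves to
    # the front, followed by the basic dims (reversed axis 0, then the new axis)
    assert y[::-1, idx1, None, idx2].shape == (3, 10, 4, 1)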