@@ -2,7 +2,7 @@
 import sys
 import warnings
 from collections.abc import Callable, Iterable, Sequence
-from itertools import chain, groupby
+from itertools import chain, groupby, zip_longest
 from typing import cast, overload
 
 import numpy as np
@@ -39,7 +39,7 @@
 from pytensor.tensor.blockwise import vectorize_node_fallback
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
-from pytensor.tensor.math import clip
+from pytensor.tensor.math import add, clip
 from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
@@ -63,6 +63,7 @@
 from pytensor.tensor.type_other import (
     MakeSlice,
     NoneConst,
+    NoneSliceConst,
     NoneTypeT,
     SliceConstant,
     SliceType,
@@ -844,6 +845,24 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
     return ps.as_scalar(a)
 
 
+def slice_static_length(slc, dim_length):
+    if dim_length is None:
+        # TODO: Some cases must be zero by definition, we could handle those
+        return None
+
+    entries = [None, None, None]
+    for i, entry in enumerate((slc.start, slc.stop, slc.step)):
+        if entry is None:
+            continue
+
+        try:
+            entries[i] = get_scalar_constant_value(entry)
+        except NotScalarConstantError:
+            return None
+
+    return len(range(*slice(*entries).indices(dim_length)))
+
+
 class Subtensor(COp):
     """Basic NumPy indexing operator."""
 
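For illustration (not part of the diff): `slice_static_length` delegates the length arithmetic to Python's own `slice.indices`, so for fully constant slices over a dimension of known length it agrees with NumPy. A minimal check in plain Python:

```python
import numpy as np

# Static length of x[1::2] along a dimension of known length 10
assert len(range(*slice(1, None, 2).indices(10))) == 5
assert np.arange(10)[1::2].shape == (5,)
```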
@@ -886,50 +905,15 @@ def make_node(self, x, *inputs):
                 )
 
         padded = [
-            *get_idx_list((None, *inputs), self.idx_list),
+            *indices_from_subtensor(inputs, self.idx_list),
             *[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
         ]
 
-        out_shape = []
-
-        def extract_const(value):
-            if value is None:
-                return value, True
-            try:
-                value = get_scalar_constant_value(value)
-                return value, True
-            except NotScalarConstantError:
-                return value, False
-
-        for the_slice, length in zip(padded, x.type.shape, strict=True):
-            if not isinstance(the_slice, slice):
-                continue
-
-            if length is None:
-                out_shape.append(None)
-                continue
-
-            start = the_slice.start
-            stop = the_slice.stop
-            step = the_slice.step
-
-            is_slice_const = True
-
-            start, is_const = extract_const(start)
-            is_slice_const = is_slice_const and is_const
-
-            stop, is_const = extract_const(stop)
-            is_slice_const = is_slice_const and is_const
-
-            step, is_const = extract_const(step)
-            is_slice_const = is_slice_const and is_const
-
-            if not is_slice_const:
-                out_shape.append(None)
-                continue
-
-            slice_length = len(range(*slice(start, stop, step).indices(length)))
-            out_shape.append(slice_length)
+        out_shape = [
+            slice_static_length(slc, length)
+            for slc, length in zip(padded, x.type.shape, strict=True)
+            if isinstance(slc, slice)
+        ]
 
         return Apply(
             self,
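For illustration (not part of the diff), a sketch of the static shapes this `make_node` should report, assuming the usual `pytensor.tensor` entry points (`pt.tensor`, `pt.scalar`):

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(10, None))

# Constant slice over a dimension with known static length: the length is inferred.
assert x[1:8:2].type.shape == (4, None)  # len(range(*slice(1, 8, 2).indices(10))) == 4

# Symbolic bound or unknown dimension length: falls back to None.
stop = pt.scalar("stop", dtype="int64")
assert x[:stop].type.shape == (None, None)
```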
@@ -2826,36 +2810,112 @@ class AdvancedSubtensor(Op):
 
     __props__ = ()
 
-    def make_node(self, x, *index):
+    def make_node(self, x, *indices):
         x = as_tensor_variable(x)
-        index = tuple(map(as_index_variable, index))
+        indices = tuple(map(as_index_variable, indices))
+
+        explicit_indices = []
+        new_axes = []
+        for idx in indices:
+            if isinstance(idx.type, TensorType) and idx.dtype == "bool":
+                if idx.type.ndim == 0:
+                    raise NotImplementedError(
+                        "Indexing with scalar booleans not supported"
+                    )
 
-        # We create a fake symbolic shape tuple and identify the broadcast
-        # dimensions from the shape result of this entire subtensor operation.
-        with config.change_flags(compute_test_value="off"):
-            fake_shape = tuple(
-                tensor(dtype="int64", shape=()) if s != 1 else 1 for s in x.type.shape
-            )
+                # Check that the static shapes are aligned
+                axis = len(explicit_indices) - len(new_axes)
+                indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
+                for j, (indexed_length, indexer_length) in enumerate(
+                    zip(indexed_shape, idx.type.shape)
+                ):
+                    if (
+                        indexed_length is not None
+                        and indexer_length is not None
+                        and indexed_length != indexer_length
+                    ):
+                        raise IndexError(
+                            f"boolean index did not match indexed tensor along axis {axis + j}; "
+                            f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
+                        )
+                # Convert boolean indices to integer indices with nonzero, to reason about static shapes next
+                if isinstance(idx, Constant):
+                    nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
+                else:
+                    # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
+                    # and seeing that other integer indices cannot possibly match it
+                    nonzero_indices = idx.nonzero()
+                explicit_indices.extend(nonzero_indices)
+            else:
+                if isinstance(idx.type, NoneTypeT):
+                    new_axes.append(len(explicit_indices))
+                explicit_indices.append(idx)
 
-        fake_index = tuple(
-            chain.from_iterable(
-                pytensor.tensor.basic.nonzero(idx)
-                if getattr(idx, "ndim", 0) > 0
-                and getattr(idx, "dtype", None) == "bool"
-                else (idx,)
-                for idx in index
-            )
+        if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
+            raise IndexError(
+                f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
         )
 
-        out_shape = tuple(
-            i.value if isinstance(i, Constant) else None
-            for i in indexed_result_shape(fake_shape, fake_index)
-        )
+        # Perform basic and advanced indexing shape inference separately
+        basic_group_shape = []
+        advanced_indices = []
+        adv_group_axis = None
+        last_adv_group_axis = None
+        expanded_x_shape = tuple(
+            np.insert(np.array(x.type.shape, dtype=object), new_axes, 1)  # place 1s at the new-axis positions
+        )
+        for i, (idx, dim_length) in enumerate(
+            zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
+        ):
+            if isinstance(idx.type, NoneTypeT):
+                basic_group_shape.append(1)  # New axis
+            elif isinstance(idx.type, SliceType):
+                if isinstance(idx, Constant):
+                    basic_group_shape.append(slice_static_length(idx.data, dim_length))
+                elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
+                    basic_group_shape.append(
+                        slice_static_length(slice(*idx.owner.inputs), dim_length)
+                    )
+                else:
+                    # Symbolic root slice (owner is None), or a slice operation we don't understand
+                    basic_group_shape.append(None)
+            else:  # TensorType
+                # Keep track of the advanced group axis
+                if adv_group_axis is None:
+                    # First time we see an advanced index
+                    adv_group_axis, last_adv_group_axis = i, i
+                elif last_adv_group_axis == (i - 1):
+                    # Another advanced index contiguous with the first group
+                    last_adv_group_axis = i
+                else:
+                    # Non-consecutive advanced indices: the whole advanced group gets moved to the front
+                    adv_group_axis = 0
+                advanced_indices.append(idx)
+
+        if advanced_indices:
+            try:
+                # Use variadic add to infer the broadcast static shape of the advanced integer indices
+                advanced_group_static_shape = add(*advanced_indices).type.shape
+            except ValueError:
+                # It fails when the static shapes are inconsistent
+                static_shapes = [idx.type.shape for idx in advanced_indices]
+                raise IndexError(
+                    f"shape mismatch: indexing tensors could not be broadcast together with shapes {static_shapes}"
+                )
+            # Combine the advanced and basic views
+            indexed_shape = [
+                *basic_group_shape[:adv_group_axis],
+                *advanced_group_static_shape,
+                *basic_group_shape[adv_group_axis:],
+            ]
+        else:
+            # This could have been a basic subtensor!
+            indexed_shape = basic_group_shape
 
         return Apply(
             self,
-            (x, *index),
-            [tensor(dtype=x.type.dtype, shape=out_shape)],
+            [x, *indices],
+            [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
         )
 
     def R_op(self, inputs, eval_points):
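For illustration (not part of the diff): the advanced-group placement above mirrors NumPy's indexing rules, which the inferred static shape has to match. Consecutive advanced indices keep the broadcast group at its axis position; non-consecutive ones move it to the front (the `adv_group_axis = 0` branch). In plain NumPy:

```python
import numpy as np

x = np.empty((5, 6, 7))
idx = np.zeros(2, dtype=int)

# Consecutive advanced indices: broadcast shape (2,) stays at its position.
assert x[:, idx, idx].shape == (5, 2)

# Non-consecutive advanced indices (separated by a slice): the group moves to the front.
assert x[idx, :, idx].shape == (2, 6)

# The index shapes themselves broadcast, which add(*advanced_indices) reproduces symbolically.
assert x[np.zeros((2, 1), dtype=int), np.zeros((3,), dtype=int)].shape == (2, 3, 7)
```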