|
33 | 33 | alloc, |
34 | 34 | get_scalar_constant_value, |
35 | 35 | nonzero, |
| 36 | + switch, |
36 | 37 | ) |
37 | 38 | from pytensor.tensor.basic import ( |
38 | 39 | constant as tensor_constant, |
39 | 40 | ) |
40 | 41 | from pytensor.tensor.blockwise import vectorize_node_fallback |
41 | 42 | from pytensor.tensor.elemwise import DimShuffle |
42 | 43 | from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError |
43 | | -from pytensor.tensor.math import clip |
| 44 | +from pytensor.tensor.math import abs as pt_abs |
| 45 | +from pytensor.tensor.math import clip, eq, ge, lt, maximum, minimum, sign |
44 | 46 | from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable |
45 | 47 | from pytensor.tensor.type import ( |
46 | 48 | TensorType, |
|
55 | 57 | lscalar, |
56 | 58 | tensor, |
57 | 59 | ubscalar, |
| 60 | + uint_dtypes, |
58 | 61 | uiscalar, |
59 | 62 | ulscalar, |
60 | 63 | uwscalar, |
@@ -254,6 +257,25 @@ def get_idx_list(inputs, idx_list): |
254 | 257 | return indices_from_subtensor(inputs[1:], idx_list) |
255 | 258 |
|
256 | 259 |
|
| 260 | +def undo_scalarization(x) -> TensorVariable: |
| 261 | + """Undo scalarization of a variable. |
| 262 | +
|
| 263 | + PyTensor Basic index operations use ScalarVariables for the indices/slice arguments. |
| 264 | + But when reasoning symbolically about the result of multiple indexing operations, we usually |
| 265 | + want to work with TensorVariables, since rewrites operate on those and not on ScalarVariables. |
| 266 | +
|
| 267 | + This function undoes the ScalarFromTensor operation or converts ScalarConstants to TensorConstants. |
| 268 | + """ |
| 269 | + if isinstance(x, ScalarVariable): |
| 270 | + if isinstance(x, ScalarConstant): |
| 271 | + return tensor_constant(x.data, dtype=x.dtype) |
| 272 | + elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor): |
| 273 | + return x.owner.inputs[0] |
| 274 | + else: |
| 275 | + return as_tensor_variable(x) |
| 276 | + return x |
| 277 | + |
| 278 | + |
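
For orientation, a minimal usage sketch of the now top-level helper (an illustrative sketch assuming a standard PyTensor environment, where `scalar_from_tensor` in `pytensor.tensor.basic` wraps the `ScalarFromTensor` op):

```python
# Sketch, not part of the diff: undo_scalarization unwraps a ScalarFromTensor
# node back to the original TensorVariable instead of converting anew.
import pytensor.tensor as pt
from pytensor.tensor.basic import scalar_from_tensor

x = pt.lscalar("x")                # 0-d int64 TensorVariable
s = scalar_from_tensor(x)          # ScalarVariable produced by ScalarFromTensor

assert undo_scalarization(s) is x  # recovers x without an extra conversion node
assert undo_scalarization(x) is x  # non-ScalarVariables pass through unchanged
```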
257 | 279 | @overload |
258 | 280 | def get_canonical_form_slice( |
259 | 281 | theslice: slice, |
@@ -296,25 +318,6 @@ def get_canonical_form_slice( |
296 | 318 | direction |
297 | 319 | Direction to iterate the resulting elements in. (-1 or 1). May be symbolic. |
298 | 320 | """ |
299 | | - from pytensor.tensor import ge, lt, sign, switch |
300 | | - |
301 | | - def undo_scalarization(x): |
302 | | - """Undo scalarization of a variable. |
303 | | -
|
304 | | - PyTensor Basic index operations use ScalarVariables for the indices/slice arguments. |
305 | | - But reasoning symbolically about the result of multiple indexing operations, we usually |
306 | | - want to work on TensorVariables, since rewrites work on those and not ScalarVariables. |
307 | | -
|
308 | | - This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants. |
309 | | - """ |
310 | | - if isinstance(x, ScalarVariable): |
311 | | - if isinstance(x, ScalarConstant): |
312 | | - return tensor_constant(x.data, dtype=x.dtype) |
313 | | - elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor): |
314 | | - return x.owner.inputs[0] |
315 | | - else: |
316 | | - return as_tensor_variable(x) |
317 | | - return x |
318 | 321 |
|
319 | 322 | def analyze(x): |
320 | 323 | try: |
@@ -845,6 +848,17 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable: |
845 | 848 | return ps.as_scalar(a) |
846 | 849 |
|
847 | 850 |
|
| 851 | +def _eager_switch( |
| 852 | + cond: TensorVariable | bool, a: TensorVariable, b: TensorVariable |
| 853 | +) -> TensorVariable: |
| 854 | + # Do not create a switch if cond is a Python True/False |
| 855 | + # We need this because uint types cannot be negative, and building the lazy switch could upcast everything to float64 |
| 856 | + # It also immediately simplifies the returned graph |
| 857 | + if isinstance(cond, bool): |
| 858 | + return a if cond else b |
| 859 | + return switch(cond, a, b) |
| 860 | + |
| 861 | + |
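
A quick illustration of the eager/lazy split (a sketch assuming `_eager_switch` is in scope): with a Python bool the chosen branch is returned directly and no graph node is built, so a uint operand never has to share a dtype with a potentially negative branch:

```python
import pytensor.tensor as pt

xl = pt.lscalar("xl")

assert _eager_switch(True, xl, -xl) is xl    # eager: branch returned as-is
lazy = _eager_switch(pt.ge(xl, 0), xl, -xl)  # symbolic cond: lazy switch built
assert lazy.owner is not None                # an Elemwise switch Apply node
```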
848 | 862 | class Subtensor(COp): |
849 | 863 | """Basic NumPy indexing operator.""" |
850 | 864 |
|
@@ -956,27 +970,112 @@ def infer_shape(self, fgraph, node, shapes): |
956 | 970 | padded = actual_idx_list + [slice(None, None, None)] * ( |
957 | 971 | len(xshp) - len(self.idx_list) |
958 | 972 | ) |
| 973 | + |
| 974 | + zero = tensor_constant(np.array(0, dtype="int64")) |
| 975 | + one = tensor_constant(np.array(1, dtype="int64")) |
959 | 976 | i = 0 |
960 | 977 | for idx, xl in zip(padded, xshp, strict=True): |
961 | 978 | if isinstance(idx, slice): |
962 | | - # If it is the default (None, None, None) slice, or a variant, |
963 | | - # the shape will be xl |
| 979 | + a, b, step = idx.start, idx.stop, idx.step |
964 | 980 | if ( |
965 | | - (idx.start in [None, 0]) |
966 | | - and (idx.stop in [None, sys.maxsize]) |
967 | | - and (idx.step is None or idx.step == 1) |
| 981 | + a is None |
| 982 | + and b is None |
| 983 | + and step is not None |
| 984 | + and get_scalar_constant_value(step, raise_not_constant=False) == -1 |
968 | 985 | ): |
| 986 | + # Shortcut for x[::-1] |
969 | 987 | outshp.append(xl) |
| 988 | + |
970 | 989 | else: |
971 | | - cnf = get_canonical_form_slice(idx, xl)[0] |
972 | | - if cnf.step == 1: |
973 | | - length = cnf.stop - cnf.start |
| 990 | + if step is None: |
| 991 | + step_pos = True |
| 992 | + unit_step = True |
| 993 | + abs_step = one |
| 994 | + else: |
| 995 | + step = undo_scalarization(step) |
| 996 | + if step.dtype in uint_dtypes: |
| 997 | + step_pos = True |
| 998 | + abs_step = step.astype("int64") |
| 999 | + else: |
| 1000 | + step_pos = ge(step, zero) |
| 1001 | + abs_step = pt_abs(step) |
| 1002 | + unit_step = eq(abs_step, one) |
| 1003 | + |
| 1004 | + if a is None: |
| 1005 | + a_pos = True |
| 1006 | + a = _eager_switch(step_pos, zero, xl) |
974 | 1007 | else: |
975 | | - length = (cnf.stop - cnf.start - 1) // cnf.step + 1 |
976 | | - outshp.append(length) |
| 1008 | + a = undo_scalarization(a) |
| 1009 | + if a.dtype in uint_dtypes: |
| 1010 | + a_pos = True |
| 1011 | + a = a.astype("int64") |
| 1012 | + else: |
| 1013 | + a_pos = ge(a, zero) |
| 1014 | + |
| 1015 | + if b is None: |
| 1016 | + # For negative steps there is no numerical equivalent for stop=None. |
| 1017 | + # The formulas below work if we set it to -1 and treat `b_pos` as True |
| 1018 | + b_pos = True |
| 1019 | + b = _eager_switch(step_pos, xl, -one) |
| 1020 | + else: |
| 1021 | + b = undo_scalarization(b) |
| 1022 | + if b.dtype in uint_dtypes: |
| 1023 | + b = b.astype("int64") |
| 1024 | + b_pos = True |
| 1025 | + else: |
| 1026 | + b_pos = ge(b, zero) |
| 1027 | + |
| 1028 | + slice_length_pos_step = _eager_switch( |
| 1029 | + a_pos, |
| 1030 | + _eager_switch( |
| 1031 | + b_pos, |
| 1032 | + minimum(b - a, xl - a), # [a: b] |
| 1033 | + ((xl + b) - a), # [a: -b] |
| 1034 | + ), |
| 1035 | + _eager_switch( |
| 1036 | + b_pos, |
| 1037 | + # The [-a: b] case is peculiar: the slice length actually decreases for larger arrays |
| 1038 | + # The -a branch is redundant when b - a / 2 <= -a, and similarly for the b branch |
| 1039 | + minimum(minimum(xl, b - a - xl), minimum(-a, b)), # [-a: b] |
| 1040 | + minimum(b - a, xl + b), # [-a: -b] |
| 1041 | + ), |
| 1042 | + ) |
| 1043 | + |
| 1044 | + slice_length_neg_step = _eager_switch( |
| 1045 | + a_pos, |
| 1046 | + _eager_switch( |
| 1047 | + b_pos, |
| 1048 | + minimum(a - b, xl - b - one), # [a: b] |
| 1049 | + minimum( |
| 1050 | + minimum(xl, a - (xl + b)), minimum(a + one, -b - one) |
| 1051 | + ), # [a: -b] |
| 1052 | + ), |
| 1053 | + _eager_switch( |
| 1054 | + b_pos, |
| 1055 | + ((xl + a) - b), # [-a: b] |
| 1056 | + minimum(a - b, xl + a + one), # [-a: -b] |
| 1057 | + ), |
| 1058 | + ) |
| 1059 | + |
| 1060 | + slice_length = _eager_switch( |
| 1061 | + step_pos, |
| 1062 | + slice_length_pos_step, |
| 1063 | + slice_length_neg_step, |
| 1064 | + ) |
| 1065 | + |
| 1066 | + # Incorporate step size |
| 1067 | + slice_length = _eager_switch( |
| 1068 | + unit_step, |
| 1069 | + slice_length, |
| 1070 | + (slice_length - one) // abs_step + one, |
| 1071 | + ) |
| 1072 | + # Catch negative sizes |
| 1073 | + slice_length = maximum(zero, slice_length) |
| 1074 | + outshp.append(slice_length) |
| 1075 | + |
977 | 1076 | i += 1 |
978 | 1077 | else: |
979 | | - # That dimension is dropped |
| 1078 | + # That dimension is dropped by integer indexing |
980 | 1079 | pass |
981 | 1080 | assert i == node.outputs[0].ndim |
982 | 1081 | assert len(outshp) == node.outputs[0].ndim |
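
To sanity-check the branch formulas in this hunk, the sketch below transcribes them into plain Python and compares exhaustively against Python's own `slice.indices`, which is the ground truth for slice semantics. This is an illustrative cross-check, not code from the PR:

```python
def formula_len(a, b, step, xl):
    # Plain-int transcription of the switch branches above.
    step = 1 if step is None else step
    step_pos, abs_step = step >= 0, abs(step)
    if a is None:
        a_pos, a = True, (0 if step_pos else xl)
    else:
        a_pos = a >= 0
    if b is None:
        # stop=None with a negative step has no numeric stop: use -1, b_pos=True
        b_pos, b = True, (xl if step_pos else -1)
    else:
        b_pos = b >= 0
    if step_pos:
        if a_pos:
            length = min(b - a, xl - a) if b_pos else (xl + b) - a
        elif b_pos:
            length = min(min(xl, b - a - xl), min(-a, b))
        else:
            length = min(b - a, xl + b)
    else:
        if a_pos and b_pos:
            length = min(a - b, xl - b - 1)
        elif a_pos:
            length = min(min(xl, a - (xl + b)), min(a + 1, -b - 1))
        elif b_pos:
            length = (xl + a) - b
        else:
            length = min(a - b, xl + a + 1)
    if abs_step != 1:
        length = (length - 1) // abs_step + 1  # incorporate the step size
    return max(0, length)                      # catch negative sizes

# Ground truth: len(range(...)) after slice.indices resolves defaults/negatives.
for xl in range(7):
    for a in [None, *range(-9, 9)]:
        for b in [None, *range(-9, 9)]:
            for step in [None, 1, -1, 2, -3]:
                expected = len(range(*slice(a, b, step).indices(xl)))
                assert formula_len(a, b, step, xl) == expected, (a, b, step, xl)
```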
|
0 commit comments