Commit 0e7a27d

Simplify graph returned by Subtensor.infer_shape
1 parent 2be9843

File tree

pytensor/tensor/subtensor.py
tests/tensor/test_subtensor.py

2 files changed: +195 −33 lines

pytensor/tensor/subtensor.py

Lines changed: 130 additions & 31 deletions
@@ -33,14 +33,16 @@
     alloc,
     get_scalar_constant_value,
     nonzero,
+    switch,
 )
 from pytensor.tensor.basic import (
     constant as tensor_constant,
 )
 from pytensor.tensor.blockwise import vectorize_node_fallback
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
-from pytensor.tensor.math import clip
+from pytensor.tensor.math import abs as pt_abs
+from pytensor.tensor.math import clip, eq, ge, lt, maximum, minimum, sign
 from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
@@ -55,6 +57,7 @@
     lscalar,
     tensor,
     ubscalar,
+    uint_dtypes,
     uiscalar,
     ulscalar,
     uwscalar,
@@ -254,6 +257,25 @@ def get_idx_list(inputs, idx_list):
     return indices_from_subtensor(inputs[1:], idx_list)
 
 
+def undo_scalarization(x) -> TensorVariable:
+    """Undo scalarization of a variable.
+
+    PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
+    But reasoning symbolically about the result of multiple indexing operations, we usually
+    want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
+
+    This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
+    """
+    if isinstance(x, ScalarVariable):
+        if isinstance(x, ScalarConstant):
+            return tensor_constant(x.data, dtype=x.dtype)
+        elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
+            return x.owner.inputs[0]
+        else:
+            return as_tensor_variable(x)
+    return x
+
+
 @overload
 def get_canonical_form_slice(
     theslice: slice,
@@ -296,25 +318,6 @@ def get_canonical_form_slice(
     direction
         Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
     """
-    from pytensor.tensor import ge, lt, sign, switch
-
-    def undo_scalarization(x):
-        """Undo scalarization of a variable.
-
-        PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
-        But reasoning symbolically about the result of multiple indexing operations, we usually
-        want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
-
-        This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
-        """
-        if isinstance(x, ScalarVariable):
-            if isinstance(x, ScalarConstant):
-                return tensor_constant(x.data, dtype=x.dtype)
-            elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
-                return x.owner.inputs[0]
-            else:
-                return as_tensor_variable(x)
-        return x
 
     def analyze(x):
         try:
@@ -845,6 +848,17 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
     return ps.as_scalar(a)
 
 
+def _eager_switch(
+    cond: TensorVariable | bool, a: TensorVariable, b: TensorVariable
+) -> TensorVariable:
+    # Do not create a switch if cond is True/False
+    # We need this because uint types cannot be negative and creating the lazy switch could upcast everything to float64
+    # It also simplifies immediately the graph that's returned
+    if isinstance(cond, bool):
+        return a if cond else b
+    return switch(cond, a, b)
+
+
 class Subtensor(COp):
     """Basic NumPy indexing operator."""
 
@@ -956,27 +970,112 @@ def infer_shape(self, fgraph, node, shapes):
         padded = actual_idx_list + [slice(None, None, None)] * (
             len(xshp) - len(self.idx_list)
         )
+
+        zero = tensor_constant(np.array(0, dtype="int64"))
+        one = tensor_constant(np.array(1, dtype="int64"))
         i = 0
         for idx, xl in zip(padded, xshp, strict=True):
             if isinstance(idx, slice):
-                # If it is the default (None, None, None) slice, or a variant,
-                # the shape will be xl
+                a, b, step = idx.start, idx.stop, idx.step
                 if (
-                    (idx.start in [None, 0])
-                    and (idx.stop in [None, sys.maxsize])
-                    and (idx.step is None or idx.step == 1)
+                    a is None
+                    and b is None
+                    and step is not None
+                    and get_scalar_constant_value(step, raise_not_constant=False) == -1
                 ):
+                    # Shortcut for x[::-1]
                     outshp.append(xl)
+
                 else:
-                    cnf = get_canonical_form_slice(idx, xl)[0]
-                    if cnf.step == 1:
-                        length = cnf.stop - cnf.start
+                    if step is None:
+                        step_pos = True
+                        unit_step = True
+                        abs_step = one
+                    else:
+                        step = undo_scalarization(step)
+                        if step.dtype in uint_dtypes:
+                            step_pos = True
+                            abs_step = step.astype("int64")
+                        else:
+                            step_pos = ge(step, zero)
+                            abs_step = pt_abs(step)
+                        unit_step = eq(abs_step, one)
+
+                    if a is None:
+                        a_pos = True
+                        a = _eager_switch(step_pos, zero, xl)
                     else:
-                        length = (cnf.stop - cnf.start - 1) // cnf.step + 1
-                    outshp.append(length)
+                        a = undo_scalarization(a)
+                        if a.dtype in uint_dtypes:
+                            a_pos = True
+                            a = a.astype("int64")
+                        else:
+                            a_pos = ge(a, zero)
+
+                    if b is None:
+                        # For negative steps there is no numerical equivalent for stop=None.
+                        # The formulas below work if we set it to -1 and consider `b_pos=True`
+                        b_pos = True
+                        b = _eager_switch(step_pos, xl, -one)
+                    else:
+                        b = undo_scalarization(b)
+                        if b.dtype in uint_dtypes:
+                            b = b.astype("int64")
+                            b_pos = True
+                        else:
+                            b_pos = ge(b, zero)
+
+                    slice_length_pos_step = _eager_switch(
+                        a_pos,
+                        _eager_switch(
+                            b_pos,
+                            minimum(b - a, xl - a),  # [a: b]
+                            ((xl + b) - a),  # [a: -b]
+                        ),
+                        _eager_switch(
+                            b_pos,
+                            # The [-a: b] is peculiar, the slice length actually decreases for larger arrays
+                            # The branch -a is useless when b - a / 2 <= -a. Similar for the branch b
+                            minimum(minimum(xl, b - a - xl), minimum(-a, b)),  # [-a: b]
+                            minimum(b - a, xl + b),  # [-a: -b]
+                        ),
+                    )
+
+                    slice_length_neg_step = _eager_switch(
+                        a_pos,
+                        _eager_switch(
+                            b_pos,
+                            minimum(a - b, xl - b - one),  # [a: b]
+                            minimum(
+                                minimum(xl, a - (xl + b)), minimum(a + one, -b - one)
+                            ),  # [a: -b]
+                        ),
+                        _eager_switch(
+                            b_pos,
+                            ((xl + a) - b),  # [-a: b]
+                            minimum(a - b, xl + a + one),  # [-a: -b]
+                        ),
+                    )
+
+                    slice_length = _eager_switch(
+                        step_pos,
+                        slice_length_pos_step,
+                        slice_length_neg_step,
+                    )
+
+                    # Incorporate step size
+                    slice_length = _eager_switch(
+                        unit_step,
+                        slice_length,
+                        (slice_length - one) // abs_step + one,
+                    )
+                    # Catch negative sizes
+                    slice_length = maximum(zero, slice_length)
+                    outshp.append(slice_length)
+
                 i += 1
             else:
-                # That dimension is dropped
+                # That dimension is dropped by integer indexing
                 pass
         assert i == node.outputs[0].ndim
         assert len(outshp) == node.outputs[0].ndim
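
The nested minimum/_eager_switch expressions above implement closed-form slice lengths for every sign combination of start (a), stop (b), and step. As a sanity check, here is a minimal pure-NumPy sketch, not part of the commit, that mirrors the same branch structure with plain Python ints and brute-forces every combination against NumPy's actual slicing (the helper name slice_length is ours):

import numpy as np

def slice_length(xl, a, b, step):
    # Pure-Python mirror of the switch tree built in Subtensor.infer_shape
    step_pos = step is None or step >= 0
    abs_step = 1 if step is None else abs(step)
    if a is None:
        # start=None means 0 for positive steps, the last element otherwise
        a, a_pos = (0, True) if step_pos else (xl, True)
    else:
        a_pos = a >= 0
    if b is None:
        # stop=None has no numeric equivalent for negative steps;
        # -1 works if it is routed through the b_pos=True branches below
        b, b_pos = (xl, True) if step_pos else (-1, True)
    else:
        b_pos = b >= 0
    if step_pos:
        if a_pos:
            length = min(b - a, xl - a) if b_pos else (xl + b) - a
        elif b_pos:
            # [-a: b] shrinks as the array grows, hence the extra min terms
            length = min(min(xl, b - a - xl), min(-a, b))
        else:
            length = min(b - a, xl + b)
    else:
        if a_pos:
            if b_pos:
                length = min(a - b, xl - b - 1)
            else:
                length = min(min(xl, a - (xl + b)), min(a + 1, -b - 1))
        elif b_pos:
            length = (xl + a) - b
        else:
            length = min(a - b, xl + a + 1)
    if abs_step != 1:
        length = (length - 1) // abs_step + 1  # fold in the step size
    return max(0, length)  # catch negative (empty) sizes

# Brute-force comparison against NumPy for every bounds/step combination
for xl in range(8):
    x = np.arange(xl)
    for a in (None, *range(-9, 9)):
        for b in (None, *range(-9, 9)):
            for step in (None, 1, 2, 5, -1, -2, -5):
                assert slice_length(xl, a, b, step) == x[a:b:step].shape[0]

The last two steps of the helper mirror the "# Incorporate step size" and "# Catch negative sizes" lines of the graph version: ceiling division by |step|, then clamping at zero so out-of-order bounds yield an empty slice.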

tests/tensor/test_subtensor.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad
-from pytensor.graph import Constant
+from pytensor.graph import Constant, FunctionGraph
 from pytensor.graph.basic import equal_computations
 from pytensor.graph.op import get_test_value
-from pytensor.graph.rewriting.utils import is_same_graph
+from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
 from pytensor.printing import pprint
 from pytensor.scalar.basic import as_scalar, int16
 from pytensor.tensor import as_tensor, get_vector_length, vectorize
@@ -71,6 +71,7 @@
     lscalar,
     lvector,
     matrix,
+    scalar,
     tensor,
     tensor3,
     tensor4,
@@ -1466,6 +1467,68 @@ def test_adv1_inc_sub_notlastdim_1_2dval_no_broadcast(self):
         assert np.allclose(m2_val, m2_ref), (m2_val, m2_ref)
 
 
+class TestSubtensorInferShape:
+    _NO_OPT_MODE = Mode(linker="py", optimizer=None)
+
+    @pytest.mark.parametrize(
+        "b", [None, 0, 1, 7, 13, -1, -7, -13], ids=lambda x: f"b={x}"
+    )
+    @pytest.mark.parametrize(
+        "a", [None, 0, 1, 7, 13, -1, -7, -13], ids=lambda x: f"a={x}"
+    )
+    @pytest.mark.parametrize("step", [None, 1, 3, -1, -4], ids=lambda x: f"step={x}")
+    def test_slice_infer_shape(self, a, b, step):
+        x = vector("x", dtype="int64")
+        y = x[a:b:step].shape[0]
+
+        fg = FunctionGraph(outputs=[y], clone=False)
+        rewrite_graph(fg, include=("ShapeOpt", "canonicalize"), clone=False)
+
+        fn = pytensor.function(
+            [x],
+            fg.outputs[0],
+            trust_input=True,
+            mode=self._NO_OPT_MODE,
+            on_unused_input="ignore",
+        )
+        x_full = np.arange(20)
+        for n in range(0, 20):
+            x_test = x_full[:n]
+            assert fn(x_test) == x_test[a:b:step].shape[0], f"failed with {n=}"
+
+    @pytest.mark.parametrize("a_dtype", (None, "int64", "uint64"))
+    @pytest.mark.parametrize("b_dtype", (None, "int64", "uint64"))
+    @pytest.mark.parametrize("step_dtype", (None, "int64", "uint64"))
+    def test_slice_infer_shape_uint(self, a_dtype, b_dtype, step_dtype):
+        a = None if a_dtype is None else scalar(dtype=a_dtype)
+        b = None if b_dtype is None else scalar(dtype=b_dtype)
+        step = None if step_dtype is None else scalar(dtype=step_dtype)
+        x = vector("x", dtype="int64")
+
+        y = x[a:b:step].shape[0]
+
+        final_y = rewrite_graph(y, include=("ShapeOpt", "canonicalize"), clone=False)
+        assert final_y.dtype == "int64"
+
+        test_a = None if a is None else 1 if a_dtype.startswith("u") else -1
+        test_b = None if b is None else 10 if b_dtype.startswith("u") else -2
+        test_step = None if step is None else 2 if step_dtype.startswith("u") else -2
+        test_x = np.arange(20)
+
+        test_dict = {x: test_x}
+        if a is not None:
+            test_dict[a] = test_a
+        if b is not None:
+            test_dict[b] = test_b
+        if step is not None:
+            test_dict[step] = test_step
+
+        final_y_eval = final_y.eval(
+            test_dict, mode=self._NO_OPT_MODE, on_unused_input="ignore"
+        )
+        assert final_y_eval == test_x[test_a:test_b:test_step].shape[0]
+
+
 def test_take_basic():
     with pytest.raises(TypeError):
         take(matrix(), lvector(), axis=lscalar())
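
To see what the commit buys on a concrete graph, a sketch along these lines, not part of the commit, reuses the same ShapeOpt/canonicalize pipeline the new tests exercise; the intent is that the printed shape graph is a small arithmetic expression on x.shape[0] rather than the switch/clip chains previously produced via get_canonical_form_slice:

import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.vector("x")
y = x[2:-3].shape[0]  # symbolic length of a sliced vector

# Apply the same rewrites used in TestSubtensorInferShape, then inspect
simplified = rewrite_graph(y, include=("ShapeOpt", "canonicalize"))
pytensor.dprint(simplified)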
