Skip to content

Commit 7b7ab9e

Browse files
committed
Handle no-op Subtensors in rewrites
1 parent 0088d03 commit 7b7ab9e

File tree

3 files changed

+62
-27
lines changed

3 files changed

+62
-27
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -336,35 +336,46 @@ def local_subtensor_of_dot(fgraph, node):
336336
@node_rewriter([Subtensor])
337337
def local_useless_slice(fgraph, node):
338338
"""
339-
Remove Subtensor of the form X[0, :] -> X[0]
339+
Remove Subtensor of the form:
340+
1. X[0, :] -> X[0]
341+
2. X[:] -> X
342+
340343
"""
341-
if isinstance(node.op, Subtensor):
342-
slices = get_idx_list(node.inputs, node.op.idx_list)
343-
last_slice = len(slices)
344-
for s in slices[::-1]:
345-
# check if slice and then check slice indices
346-
if (
347-
isinstance(s, slice)
348-
and s.start is None
349-
and s.stop is None
350-
and (
351-
s.step is None
352-
or extract_constant(s.step, only_process_constants=True) == 1
353-
)
354-
):
355-
last_slice -= 1
356-
else:
357-
break
358-
# check if we removed something
359-
if last_slice < len(slices):
360-
subtens = Subtensor(slices[:last_slice])
361-
sl_ins = get_slice_elements(
362-
slices[:last_slice], lambda x: isinstance(x, Variable)
344+
idxs = get_idx_list(node.inputs, node.op.idx_list)
345+
346+
if not idxs:
347+
return [node.inputs[0]]
348+
349+
last_useless_slice = len(idxs)
350+
for s in idxs[::-1]:
351+
# check if slice and then check slice indices
352+
if (
353+
isinstance(s, slice)
354+
and s.start is None
355+
and s.stop is None
356+
and (
357+
s.step is None
358+
or extract_constant(s.step, only_process_constants=True) == 1
359+
)
360+
):
361+
last_useless_slice -= 1
362+
else:
363+
break
364+
# check if we removed something
365+
if last_useless_slice < len(idxs):
366+
new_idxs = idxs[:last_useless_slice]
367+
if new_idxs:
368+
new_subtensor = Subtensor(new_idxs)
369+
new_subtensor_inputs = get_slice_elements(
370+
new_idxs, lambda x: isinstance(x, Variable)
363371
)
364-
out = subtens(node.inputs[0], *sl_ins)
372+
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
365373
# Copy over previous output stacktrace
366374
copy_stack_trace(node.outputs, out)
367375
return [out]
376+
else:
377+
# Subtensor is not needed at all
378+
return [node.inputs[0]]
368379

369380

370381
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@@ -747,7 +758,13 @@ def local_subtensor_make_vector(fgraph, node):
747758
make_vector_op = x.owner.op
748759

749760
if isinstance(node.op, Subtensor):
750-
(idx,) = node.op.idx_list
761+
idxs = node.op.idx_list
762+
763+
# Subtensor has no indexes, return make_vector
764+
if not idxs:
765+
return [x]
766+
767+
(idx,) = idxs
751768

752769
if isinstance(idx, (aes.ScalarType, TensorType)):
753770
old_idx, idx = idx, node.inputs[1]
@@ -903,7 +920,11 @@ def local_set_to_inc_subtensor(fgraph, node):
903920
@node_rewriter([Subtensor])
904921
def local_useless_subtensor(fgraph, node):
905922
"""Remove `Subtensor` if it takes the full input."""
906-
# This optimization needs ShapeOpt and fgraph.shape_feature
923+
924+
if not node.op.idx_list:
925+
return [node.inputs[0]]
926+
927+
# The more elaborate optimization needs ShapeOpt and fgraph.shape_feature
907928
if not hasattr(fgraph, "shape_feature"):
908929
return
909930

tests/tensor/rewriting/test_subtensor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.compile.mode import Mode, get_default_mode, get_mode
1010
from pytensor.compile.ops import DeepCopyOp
1111
from pytensor.configdefaults import config
12+
from pytensor.graph import FunctionGraph
1213
from pytensor.graph.basic import Constant, Variable, ancestors
1314
from pytensor.graph.rewriting.basic import check_stack_trace
1415
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -21,6 +22,7 @@
2122
from pytensor.tensor.math import Dot, add, dot, exp, sqr
2223
from pytensor.tensor.rewriting.subtensor import (
2324
local_replace_AdvancedSubtensor,
25+
local_subtensor_make_vector,
2426
local_subtensor_shape_constant,
2527
)
2628
from pytensor.tensor.shape import (
@@ -764,6 +766,17 @@ def test_stack_trace(self):
764766
f = function([x, y, z], v_subtensor, mode=mode)
765767
assert check_stack_trace(f, ops_to_check="all")
766768

769+
def test_empty_subtensor(self):
770+
x, y = lscalars("xy")
771+
v = make_vector(x, y)
772+
out = v[()]
773+
774+
fgraph = FunctionGraph(outputs=[out], clone=False)
775+
node = fgraph.outputs[0].owner
776+
assert isinstance(node.op, Subtensor)
777+
778+
assert local_subtensor_make_vector.transform(fgraph, node) == [v]
779+
767780

768781
class TestLocalSubtensorLift:
769782
def test_basic(self):

tests/tensor/test_subtensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,8 @@ def test_0_dims(self):
389389
t = Subtensor([])(n)
390390
assert isinstance(t.owner.op, Subtensor)
391391
self.eval_output_and_check(
392-
t, mode=self.mode.excluding("local_useless_subtensor")
392+
t,
393+
mode=self.mode.excluding("local_useless_subtensor", "local_useless_slice"),
393394
)
394395

395396
def test_err_invalid_2(self):

0 commit comments

Comments
 (0)