Skip to content

Commit 0b558d8

Browse files
committed
Retain more precise types in MergeOptimizer
This can avoid some infinite rewrite loops in which a SpecifyShape is lifted, removed, and then reintroduced at the bottom by the MergeOptimizer.
1 parent 4cc13bc commit 0b558d8

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -743,14 +743,22 @@ def apply(self, fgraph):
743743
):
744744
continue
745745

746-
if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type:
747-
res = pairs[0][0].type.convert_variable(pairs[0][1])
748-
749-
# Since the fgraph.replace only checks the convert_variable
750-
# in one way, we change the order in the case that
751-
# convert_variable will not be successful.
752-
if not res:
753-
pairs = [(pairs[0][1], pairs[0][0])]
746+
# Keep the variable with the most specific static type from the pairs
747+
# E.g., the second in (TensorType(shape=(None,)), TensorType(shape=(5,)))
748+
# Otherwise we could end up reverting type inference progress done elsewhere.
749+
for pair_idx in range(len(pairs)):
750+
old, new = pairs[pair_idx]
751+
if old.type == new.type:
752+
continue
753+
# Check if type of new replacement is at least as specific as that of the old variable
754+
if not old.type.is_super(new.type):
755+
# Check the other way around
756+
if new.type.is_super(old.type):
757+
pairs[pair_idx] = (new, old)
758+
else:
759+
# Replacement requires some operation like specify_shape
760+
new_repl = old.type.convert_variable(new)
761+
pairs[pair_idx] = (old, new_repl)
754762

755763
try:
756764
# If they're all `AtomicVariable`s, there's no need to call validate.

tests/graph/rewriting/test_basic.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
pre_greedy_node_rewriter,
2222
)
2323
from pytensor.raise_op import assert_op
24-
from pytensor.tensor.math import Dot, add, dot
24+
from pytensor.tensor.math import Dot, add, dot, exp
2525
from pytensor.tensor.rewriting.basic import constant_folding
2626
from pytensor.tensor.subtensor import AdvancedSubtensor
27-
from pytensor.tensor.type import matrix, values_eq_approx_always_true
27+
from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector
2828
from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
2929
from tests.graph.utils import (
3030
MyOp,
@@ -441,6 +441,23 @@ def test_merge_noinput(self):
441441
assert fg.outputs[0] is fg.outputs[1]
442442
assert fg.outputs[0] is not fg.outputs[2]
443443

444+
@pytest.mark.parametrize("reverse", [False, True])
445+
def test_merge_more_specific_types(self, reverse):
446+
"""Check that we choose the most specific static type when merging variables."""
447+
448+
x1 = vector("x1", shape=(None,))
449+
x2 = vector("x2", shape=(500,))
450+
451+
y1 = exp(x1)
452+
y2 = exp(x2)
453+
454+
# Simulate case where we find that x2 is equivalent to x1
455+
fg = FunctionGraph([x1, x2], [y2, y1] if reverse else [y1, y2], clone=False)
456+
fg.replace(x1, x2)
457+
458+
MergeOptimizer().rewrite(fg)
459+
assert fg.outputs == [y2, y2]
460+
444461

445462
class TestEquilibrium:
446463
def test_1(self):

0 commit comments

Comments
 (0)