|
21 | 21 | pre_greedy_node_rewriter,
|
22 | 22 | )
|
23 | 23 | from pytensor.raise_op import assert_op
|
24 |
| -from pytensor.tensor.math import Dot, add, dot |
| 24 | +from pytensor.tensor.math import Dot, add, dot, exp |
25 | 25 | from pytensor.tensor.rewriting.basic import constant_folding
|
26 | 26 | from pytensor.tensor.subtensor import AdvancedSubtensor
|
27 |
| -from pytensor.tensor.type import matrix, values_eq_approx_always_true |
| 27 | +from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector |
28 | 28 | from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
|
29 | 29 | from tests.graph.utils import (
|
30 | 30 | MyOp,
|
@@ -441,6 +441,23 @@ def test_merge_noinput(self):
|
441 | 441 | assert fg.outputs[0] is fg.outputs[1]
|
442 | 442 | assert fg.outputs[0] is not fg.outputs[2]
|
443 | 443 |
|
| 444 | + @pytest.mark.parametrize("reverse", [False, True]) |
| 445 | + def test_merge_more_specific_types(self, reverse): |
| 446 | + """Check that we choose the most specific static type when merging variables.""" |
| 447 | + |
| 448 | + x1 = vector("x1", shape=(None,)) |
| 449 | + x2 = vector("x2", shape=(500,)) |
| 450 | + |
| 451 | + y1 = exp(x1) |
| 452 | + y2 = exp(x2) |
| 453 | + |
| 454 | + # Simulate case where we find that x2 is equivalent to x1 |
| 455 | + fg = FunctionGraph([x1, x2], [y2, y1] if reverse else [y1, y2], clone=False) |
| 456 | + fg.replace(x1, x2) |
| 457 | + |
| 458 | + MergeOptimizer().rewrite(fg) |
| 459 | + assert fg.outputs == [y2, y2] |
| 460 | + |
444 | 461 |
|
445 | 462 | class TestEquilibrium:
|
446 | 463 | def test_1(self):
|
|
0 commit comments