Skip to content

Commit 5a39ef6

Browse files
committed
Do not check rewrites based on string representation
1 parent 0a83939 commit 5a39ef6

File tree

2 files changed

+54
-112
lines changed

2 files changed

+54
-112
lines changed

tests/tensor/rewriting/test_basic.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pytensor.compile.mode import get_default_mode, get_mode
1313
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
1414
from pytensor.configdefaults import config
15+
from pytensor.graph.basic import equal_computations
1516
from pytensor.graph.fg import FunctionGraph
1617
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1718
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -1410,31 +1411,28 @@ def simple_rewrite(self, g):
14101411

14111412
def test_matrix_matrix(self):
14121413
a, b = matrices("ab")
1413-
g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T]))
1414-
sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a)))"
1415-
assert str(g) == sg, (str(g), sg)
1414+
g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T], clone=False))
1415+
assert equal_computations(g.outputs, [dot(b.T, a.T)])
14161416
assert check_stack_trace(g, ops_to_check="all")
14171417

14181418
def test_row_matrix(self):
14191419
a = vector("a")
14201420
b = matrix("b")
14211421
g = rewrite(
1422-
FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T]),
1422+
FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T], clone=False),
14231423
level="stabilize",
14241424
)
1425-
sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{0,x}(a)))"
1426-
assert str(g) == sg, (str(g), sg)
1425+
assert equal_computations(g.outputs, [dot(b.T, a.dimshuffle(0, "x"))])
14271426
assert check_stack_trace(g, ops_to_check="all")
14281427

14291428
def test_matrix_col(self):
14301429
a = vector("a")
14311430
b = matrix("b")
14321431
g = rewrite(
1433-
FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T]),
1432+
FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T], clone=False),
14341433
level="stabilize",
14351434
)
1436-
sg = "FunctionGraph(dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b)))"
1437-
assert str(g) == sg, (str(g), sg)
1435+
assert equal_computations(g.outputs, [dot(a.dimshuffle("x", 0), b.T)])
14381436
assert check_stack_trace(g, ops_to_check="all")
14391437

14401438

tests/tensor/rewriting/test_elemwise.py

Lines changed: 47 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.compile.mode import Mode, get_default_mode
1313
from pytensor.configdefaults import config
1414
from pytensor.gradient import grad
15-
from pytensor.graph.basic import Constant
15+
from pytensor.graph.basic import Constant, equal_computations
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -86,113 +86,66 @@ def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
8686

8787
class TestDimshuffleLift:
8888
def test_double_transpose(self):
89-
x, y, z = inputs()
89+
x, *_ = inputs()
9090
e = ds(ds(x, (1, 0)), (1, 0))
91-
g = FunctionGraph([x], [e])
92-
# TODO FIXME: Construct these graphs and compare them.
93-
assert (
94-
str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
95-
)
91+
g = FunctionGraph([x], [e], clone=False)
92+
assert isinstance(g.outputs[0].owner.op, DimShuffle)
9693
dimshuffle_lift.rewrite(g)
97-
assert str(g) == "FunctionGraph(x)"
94+
assert g.outputs[0] is x
9895
# no need to check_stack_trace as graph is supposed to be empty
9996

10097
def test_merge2(self):
101-
x, y, z = inputs()
98+
x, *_ = inputs()
10299
e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
103-
g = FunctionGraph([x], [e])
104-
# TODO FIXME: Construct these graphs and compare them.
105-
assert (
106-
str(g)
107-
== "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
108-
), str(g)
100+
g = FunctionGraph([x], [e], clone=False)
101+
assert len(g.apply_nodes) == 2
109102
dimshuffle_lift.rewrite(g)
110-
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
103+
assert equal_computations(g.outputs, [x.dimshuffle(0, 1, "x", "x")])
111104
# Check stacktrace was copied over correctly after rewrite was applied
112105
assert check_stack_trace(g, ops_to_check="all")
113106

114107
def test_elim3(self):
115108
x, y, z = inputs()
116109
e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
117-
g = FunctionGraph([x], [e])
118-
# TODO FIXME: Construct these graphs and compare them.
119-
assert str(g) == (
120-
"FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
121-
"(InplaceDimShuffle{0,x,1}(x))))"
122-
), str(g)
110+
g = FunctionGraph([x], [e], clone=False)
111+
assert isinstance(g.outputs[0].owner.op, DimShuffle)
123112
dimshuffle_lift.rewrite(g)
124-
assert str(g) == "FunctionGraph(x)", str(g)
113+
assert g.outputs[0] is x
125114
# no need to check_stack_trace as graph is supposed to be empty
126115

127116
def test_lift(self):
128117
x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
129118
e = x + y + z
130-
g = FunctionGraph([x, y, z], [e])
131-
132-
# TODO FIXME: Construct these graphs and compare them.
133-
# It does not really matter if the DimShuffles are inplace
134-
# or not.
135-
init_str_g_inplace = (
136-
"FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
137-
"(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
138-
)
139-
init_str_g_noinplace = (
140-
"FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
141-
"(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
142-
)
143-
assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g)
144-
145-
rewrite_str_g_inplace = (
146-
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
147-
"(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
148-
)
149-
rewrite_str_g_noinplace = (
150-
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
151-
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
152-
)
119+
g = FunctionGraph([x, y, z], [e], clone=False)
153120
dimshuffle_lift.rewrite(g)
154-
assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g)
121+
assert equal_computations(
122+
g.outputs,
123+
[(x.dimshuffle("x", "x", 0) + y.dimshuffle("x", 0, 1)) + z],
124+
)
155125
# Check stacktrace was copied over correctly after rewrite was applied
156126
assert check_stack_trace(g, ops_to_check="all")
157127

158128
def test_recursive_lift(self):
159-
v = vector(dtype="float64")
160-
m = matrix(dtype="float64")
129+
v = vector("v", dtype="float64")
130+
m = matrix("m", dtype="float64")
161131
out = ((v + 42) * (m + 84)).T
162-
g = FunctionGraph([v, m], [out])
163-
# TODO FIXME: Construct these graphs and compare them.
164-
init_str_g = (
165-
"FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
166-
"(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
167-
"(<TensorType(float64, (?,))>, "
168-
"InplaceDimShuffle{x}(TensorConstant{42}))), "
169-
"Elemwise{add,no_inplace}"
170-
"(<TensorType(float64, (?, ?))>, "
171-
"InplaceDimShuffle{x,x}(TensorConstant{84})))))"
172-
)
173-
assert str(g) == init_str_g
174-
new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0]
175-
new_g = FunctionGraph(g.inputs, [new_out])
176-
rewrite_str_g = (
177-
"FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
178-
"(InplaceDimShuffle{0,x}(<TensorType(float64, (?,))>), "
179-
"InplaceDimShuffle{x,x}(TensorConstant{42})), "
180-
"Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
181-
"(<TensorType(float64, (?, ?))>), "
182-
"InplaceDimShuffle{x,x}(TensorConstant{84}))))"
132+
g = FunctionGraph([v, m], [out], clone=False)
133+
new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)
134+
assert equal_computations(
135+
new_out,
136+
[(v.dimshuffle(0, "x") + 42) * (m.T + 84)],
183137
)
184-
assert str(new_g) == rewrite_str_g
185138
# Check stacktrace was copied over correctly after rewrite was applied
139+
new_g = FunctionGraph(g.inputs, new_out, clone=False)
186140
assert check_stack_trace(new_g, ops_to_check="all")
187141

188142
def test_useless_dimshuffle(self):
189-
x, _, _ = inputs()
143+
x, *_ = inputs()
190144
e = ds(x, (0, 1))
191-
g = FunctionGraph([x], [e])
192-
# TODO FIXME: Construct these graphs and compare them.
193-
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
145+
g = FunctionGraph([x], [e], clone=False)
146+
assert isinstance(g.outputs[0].owner.op, DimShuffle)
194147
dimshuffle_lift.rewrite(g)
195-
assert str(g) == "FunctionGraph(x)"
148+
assert g.outputs[0] is x
196149
# Check stacktrace was copied over correctly after rewrite was applied
197150
assert hasattr(g.outputs[0].tag, "trace")
198151

@@ -203,17 +156,10 @@ def test_dimshuffle_on_broadcastable(self):
203156
ds_y = ds(y, (2, 1, 0)) # useless
204157
ds_z = ds(z, (2, 1, 0)) # useful
205158
ds_u = ds(u, ("x")) # useful
206-
g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
207-
# TODO FIXME: Construct these graphs and compare them.
208-
assert (
209-
str(g)
210-
== "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
211-
)
159+
g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u], clone=False)
160+
assert len(g.apply_nodes) == 4
212161
dimshuffle_lift.rewrite(g)
213-
assert (
214-
str(g)
215-
== "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
216-
)
162+
assert equal_computations(g.outputs, [x, y, z.T, u.dimshuffle("x")])
217163
# Check stacktrace was copied over correctly after rewrite was applied
218164
assert hasattr(g.outputs[0].tag, "trace")
219165

@@ -237,34 +183,32 @@ def test_local_useless_dimshuffle_in_reshape():
237183
reshape_dimshuffle_row,
238184
reshape_dimshuffle_col,
239185
],
186+
clone=False,
240187
)
241-
242-
# TODO FIXME: Construct these graphs and compare them.
243-
assert str(g) == (
244-
"FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
245-
"Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
246-
"Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
247-
"Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
248-
)
188+
assert len(g.apply_nodes) == 4 * 3
249189
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
250190
useless_dimshuffle_in_reshape.rewrite(g)
251-
assert str(g) == (
252-
"FunctionGraph(Reshape{1}(vector, Shape(vector)), "
253-
"Reshape{2}(mat, Shape(mat)), "
254-
"Reshape{2}(row, Shape(row)), "
255-
"Reshape{2}(col, Shape(col)))"
191+
assert equal_computations(
192+
g.outputs,
193+
[
194+
reshape(vec, vec.shape),
195+
reshape(mat, mat.shape),
196+
reshape(row, row.shape),
197+
reshape(col, col.shape),
198+
],
256199
)
257-
258200
# Check stacktrace was copied over correctly after rewrite was applied
259201
assert check_stack_trace(g, ops_to_check="all")
260202

261203
# Check that the rewrite does not get applied when the order
262204
# of dimensions has changed.
263205
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
264-
h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
265-
str_h = str(h)
206+
h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False)
207+
assert len(h.apply_nodes) == 3
266208
useless_dimshuffle_in_reshape.rewrite(h)
267-
assert str(h) == str_h
209+
assert equal_computations(
210+
h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)]
211+
)
268212

269213

270214
class TestFusion:

0 commit comments

Comments
 (0)