12
12
from pytensor .compile .mode import Mode , get_default_mode
13
13
from pytensor .configdefaults import config
14
14
from pytensor .gradient import grad
15
- from pytensor .graph .basic import Constant
15
+ from pytensor .graph .basic import Constant , equal_computations
16
16
from pytensor .graph .fg import FunctionGraph
17
17
from pytensor .graph .rewriting .basic import check_stack_trace , out2in
18
18
from pytensor .graph .rewriting .db import RewriteDatabaseQuery
@@ -86,113 +86,66 @@ def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
86
86
87
87
class TestDimshuffleLift :
88
88
def test_double_transpose (self ):
89
- x , y , z = inputs ()
89
+ x , * _ = inputs ()
90
90
e = ds (ds (x , (1 , 0 )), (1 , 0 ))
91
- g = FunctionGraph ([x ], [e ])
92
- # TODO FIXME: Construct these graphs and compare them.
93
- assert (
94
- str (g ) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
95
- )
91
+ g = FunctionGraph ([x ], [e ], clone = False )
92
+ assert isinstance (g .outputs [0 ].owner .op , DimShuffle )
96
93
dimshuffle_lift .rewrite (g )
97
- assert str ( g ) == "FunctionGraph(x)"
94
+ assert g . outputs [ 0 ] is x
98
95
# no need to check_stack_trace as graph is supposed to be empty
99
96
100
97
def test_merge2 (self ):
101
- x , y , z = inputs ()
98
+ x , * _ = inputs ()
102
99
e = ds (ds (x , (1 , "x" , 0 )), (2 , 0 , "x" , 1 ))
103
- g = FunctionGraph ([x ], [e ])
104
- # TODO FIXME: Construct these graphs and compare them.
105
- assert (
106
- str (g )
107
- == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
108
- ), str (g )
100
+ g = FunctionGraph ([x ], [e ], clone = False )
101
+ assert len (g .apply_nodes ) == 2
109
102
dimshuffle_lift .rewrite (g )
110
- assert str ( g ) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x)) " , str ( g )
103
+ assert equal_computations ( g . outputs , [ x . dimshuffle ( 0 , 1 , "x " , "x" )] )
111
104
# Check stacktrace was copied over correctly after rewrite was applied
112
105
assert check_stack_trace (g , ops_to_check = "all" )
113
106
114
107
def test_elim3 (self ):
115
108
x , y , z = inputs ()
116
109
e = ds (ds (ds (x , (0 , "x" , 1 )), (2 , 0 , "x" , 1 )), (1 , 0 ))
117
- g = FunctionGraph ([x ], [e ])
118
- # TODO FIXME: Construct these graphs and compare them.
119
- assert str (g ) == (
120
- "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
121
- "(InplaceDimShuffle{0,x,1}(x))))"
122
- ), str (g )
110
+ g = FunctionGraph ([x ], [e ], clone = False )
111
+ assert isinstance (g .outputs [0 ].owner .op , DimShuffle )
123
112
dimshuffle_lift .rewrite (g )
124
- assert str ( g ) == "FunctionGraph(x)" , str ( g )
113
+ assert g . outputs [ 0 ] is x
125
114
# no need to check_stack_trace as graph is supposed to be empty
126
115
127
116
def test_lift (self ):
128
117
x , y , z = inputs ([False ] * 1 , [False ] * 2 , [False ] * 3 )
129
118
e = x + y + z
130
- g = FunctionGraph ([x , y , z ], [e ])
131
-
132
- # TODO FIXME: Construct these graphs and compare them.
133
- # It does not really matter if the DimShuffles are inplace
134
- # or not.
135
- init_str_g_inplace = (
136
- "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
137
- "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
138
- )
139
- init_str_g_noinplace = (
140
- "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
141
- "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
142
- )
143
- assert str (g ) in (init_str_g_inplace , init_str_g_noinplace ), str (g )
144
-
145
- rewrite_str_g_inplace = (
146
- "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
147
- "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
148
- )
149
- rewrite_str_g_noinplace = (
150
- "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
151
- "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
152
- )
119
+ g = FunctionGraph ([x , y , z ], [e ], clone = False )
153
120
dimshuffle_lift .rewrite (g )
154
- assert str (g ) in (rewrite_str_g_inplace , rewrite_str_g_noinplace ), str (g )
121
+ assert equal_computations (
122
+ g .outputs ,
123
+ [(x .dimshuffle ("x" , "x" , 0 ) + y .dimshuffle ("x" , 0 , 1 )) + z ],
124
+ )
155
125
# Check stacktrace was copied over correctly after rewrite was applied
156
126
assert check_stack_trace (g , ops_to_check = "all" )
157
127
158
128
def test_recursive_lift (self ):
159
- v = vector (dtype = "float64" )
160
- m = matrix (dtype = "float64" )
129
+ v = vector ("v" , dtype = "float64" )
130
+ m = matrix ("m" , dtype = "float64" )
161
131
out = ((v + 42 ) * (m + 84 )).T
162
- g = FunctionGraph ([v , m ], [out ])
163
- # TODO FIXME: Construct these graphs and compare them.
164
- init_str_g = (
165
- "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
166
- "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
167
- "(<TensorType(float64, (?,))>, "
168
- "InplaceDimShuffle{x}(TensorConstant{42}))), "
169
- "Elemwise{add,no_inplace}"
170
- "(<TensorType(float64, (?, ?))>, "
171
- "InplaceDimShuffle{x,x}(TensorConstant{84})))))"
172
- )
173
- assert str (g ) == init_str_g
174
- new_out = local_dimshuffle_lift .transform (g , g .outputs [0 ].owner )[0 ]
175
- new_g = FunctionGraph (g .inputs , [new_out ])
176
- rewrite_str_g = (
177
- "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
178
- "(InplaceDimShuffle{0,x}(<TensorType(float64, (?,))>), "
179
- "InplaceDimShuffle{x,x}(TensorConstant{42})), "
180
- "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
181
- "(<TensorType(float64, (?, ?))>), "
182
- "InplaceDimShuffle{x,x}(TensorConstant{84}))))"
132
+ g = FunctionGraph ([v , m ], [out ], clone = False )
133
+ new_out = local_dimshuffle_lift .transform (g , g .outputs [0 ].owner )
134
+ assert equal_computations (
135
+ new_out ,
136
+ [(v .dimshuffle (0 , "x" ) + 42 ) * (m .T + 84 )],
183
137
)
184
- assert str (new_g ) == rewrite_str_g
185
138
# Check stacktrace was copied over correctly after rewrite was applied
139
+ new_g = FunctionGraph (g .inputs , new_out , clone = False )
186
140
assert check_stack_trace (new_g , ops_to_check = "all" )
187
141
188
142
def test_useless_dimshuffle (self ):
189
- x , _ , _ = inputs ()
143
+ x , * _ = inputs ()
190
144
e = ds (x , (0 , 1 ))
191
- g = FunctionGraph ([x ], [e ])
192
- # TODO FIXME: Construct these graphs and compare them.
193
- assert str (g ) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
145
+ g = FunctionGraph ([x ], [e ], clone = False )
146
+ assert isinstance (g .outputs [0 ].owner .op , DimShuffle )
194
147
dimshuffle_lift .rewrite (g )
195
- assert str ( g ) == "FunctionGraph(x)"
148
+ assert g . outputs [ 0 ] is x
196
149
# Check stacktrace was copied over correctly after rewrite was applied
197
150
assert hasattr (g .outputs [0 ].tag , "trace" )
198
151
@@ -203,17 +156,10 @@ def test_dimshuffle_on_broadcastable(self):
203
156
ds_y = ds (y , (2 , 1 , 0 )) # useless
204
157
ds_z = ds (z , (2 , 1 , 0 )) # useful
205
158
ds_u = ds (u , ("x" )) # useful
206
- g = FunctionGraph ([x , y , z , u ], [ds_x , ds_y , ds_z , ds_u ])
207
- # TODO FIXME: Construct these graphs and compare them.
208
- assert (
209
- str (g )
210
- == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
211
- )
159
+ g = FunctionGraph ([x , y , z , u ], [ds_x , ds_y , ds_z , ds_u ], clone = False )
160
+ assert len (g .apply_nodes ) == 4
212
161
dimshuffle_lift .rewrite (g )
213
- assert (
214
- str (g )
215
- == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
216
- )
162
+ assert equal_computations (g .outputs , [x , y , z .T , u .dimshuffle ("x" )])
217
163
# Check stacktrace was copied over correctly after rewrite was applied
218
164
assert hasattr (g .outputs [0 ].tag , "trace" )
219
165
@@ -237,34 +183,32 @@ def test_local_useless_dimshuffle_in_reshape():
237
183
reshape_dimshuffle_row ,
238
184
reshape_dimshuffle_col ,
239
185
],
186
+ clone = False ,
240
187
)
241
-
242
- # TODO FIXME: Construct these graphs and compare them.
243
- assert str (g ) == (
244
- "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
245
- "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
246
- "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
247
- "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
248
- )
188
+ assert len (g .apply_nodes ) == 4 * 3
249
189
useless_dimshuffle_in_reshape = out2in (local_useless_dimshuffle_in_reshape )
250
190
useless_dimshuffle_in_reshape .rewrite (g )
251
- assert str (g ) == (
252
- "FunctionGraph(Reshape{1}(vector, Shape(vector)), "
253
- "Reshape{2}(mat, Shape(mat)), "
254
- "Reshape{2}(row, Shape(row)), "
255
- "Reshape{2}(col, Shape(col)))"
191
+ assert equal_computations (
192
+ g .outputs ,
193
+ [
194
+ reshape (vec , vec .shape ),
195
+ reshape (mat , mat .shape ),
196
+ reshape (row , row .shape ),
197
+ reshape (col , col .shape ),
198
+ ],
256
199
)
257
-
258
200
# Check stacktrace was copied over correctly after rewrite was applied
259
201
assert check_stack_trace (g , ops_to_check = "all" )
260
202
261
203
# Check that the rewrite does not get applied when the order
262
204
# of dimensions has changed.
263
205
reshape_dimshuffle_mat2 = reshape (mat .dimshuffle ("x" , 1 , "x" , 0 ), mat .shape )
264
- h = FunctionGraph ([mat ], [reshape_dimshuffle_mat2 ])
265
- str_h = str ( h )
206
+ h = FunctionGraph ([mat ], [reshape_dimshuffle_mat2 ], clone = False )
207
+ assert len ( h . apply_nodes ) == 3
266
208
useless_dimshuffle_in_reshape .rewrite (h )
267
- assert str (h ) == str_h
209
+ assert equal_computations (
210
+ h .outputs , [reshape (mat .dimshuffle ("x" , 1 , "x" , 0 ), mat .shape )]
211
+ )
268
212
269
213
270
214
class TestFusion :
0 commit comments