@@ -156,87 +156,6 @@ def forward(self, x):
156
156
torch ._dynamo .reset ()
157
157
158
158
159
- class TestFP32Accumulation (TestCase ):
160
- def test_fp32_acc (self ):
161
- class FP32Acc (torch .nn .Module ):
162
- def forward (self , input , weight ):
163
- out = torch .ops .aten .mm .default (input , weight )
164
- return out
165
-
166
- inputs = [
167
- torch .rand ((3 , 4 )).cuda (),
168
- torch .rand ((4 , 5 )).cuda (),
169
- ]
170
-
171
- fx_graph = torch .fx .symbolic_trace (FP32Acc ())
172
- expected_ops = {torch .ops .aten ._to_copy .default , torch .ops .aten .mm .default }
173
- unexpected_ops = {}
174
-
175
- unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
176
- fx_graph ,
177
- inputs ,
178
- expected_ops = expected_ops ,
179
- unexpected_ops = unexpected_ops ,
180
- min_block_size = 1 ,
181
- use_fp32_acc = True ,
182
- )
183
-
184
- self .assertEqual (
185
- len (unexpected_ops_seen ),
186
- 0 ,
187
- f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
188
- )
189
-
190
- self .assertEqual (
191
- len (expected_ops_unseen ),
192
- 0 ,
193
- f"The following expected ops were not encountered: { expected_ops_unseen } " ,
194
- )
195
- torch ._dynamo .reset ()
196
-
197
- def test_fp32_acc_for_addmm (self ):
198
- class FP32Acc (torch .nn .Module ):
199
- def forward (self , input , mat1 , mat2 ):
200
- out = torch .ops .aten .addmm .default (input , mat1 , mat2 , beta = 20 , alpha = 2 )
201
- return out
202
-
203
- inputs = [
204
- torch .rand ((3 , 5 )).cuda (),
205
- torch .rand ((3 , 4 )).cuda (),
206
- torch .rand ((4 , 5 )).cuda (),
207
- ]
208
-
209
- fx_graph = torch .fx .symbolic_trace (FP32Acc ())
210
- expected_ops = {
211
- torch .ops .aten ._to_copy .default ,
212
- torch .ops .aten .mm .default ,
213
- torch .ops .aten .add .Tensor ,
214
- }
215
- unexpected_ops = {}
216
-
217
- unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
218
- fx_graph ,
219
- inputs ,
220
- expected_ops = expected_ops ,
221
- unexpected_ops = unexpected_ops ,
222
- min_block_size = 1 ,
223
- use_fp32_acc = True ,
224
- )
225
-
226
- self .assertEqual (
227
- len (unexpected_ops_seen ),
228
- 0 ,
229
- f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
230
- )
231
-
232
- self .assertEqual (
233
- len (expected_ops_unseen ),
234
- 0 ,
235
- f"The following expected ops were not encountered: { expected_ops_unseen } " ,
236
- )
237
- torch ._dynamo .reset ()
238
-
239
-
240
159
class TestComplexSubgraph (TestCase ):
241
160
def test_complex_subgraph (self ):
242
161
BATCH = 1
0 commit comments