@@ -23,9 +23,7 @@ def test_quantizer_conv2d():
23
23
24
24
example_input = (torch .ones (1 , 4 , 32 , 32 ),)
25
25
quantizer = NeutronQuantizer ()
26
- graph_module = torch .export .export_for_training (
27
- model , example_input , strict = True
28
- ).module ()
26
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
29
27
30
28
# noinspection PyTypeChecker
31
29
m = prepare_pt2e (graph_module , quantizer )
@@ -64,9 +62,7 @@ def test_quantizer_linear():
64
62
65
63
example_input = (torch .ones (10 , 32 ),)
66
64
quantizer = NeutronQuantizer ()
67
- graph_module = torch .export .export_for_training (
68
- model , example_input , strict = True
69
- ).module ()
65
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
70
66
71
67
# noinspection PyTypeChecker
72
68
m = prepare_pt2e (graph_module , quantizer )
@@ -105,9 +101,7 @@ def test_quantizer_maxpool2d():
105
101
106
102
example_input = (torch .ones (1 , 8 , 32 , 32 ),)
107
103
quantizer = NeutronQuantizer ()
108
- graph_module = torch .export .export_for_training (
109
- model , example_input , strict = True
110
- ).module ()
104
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
111
105
112
106
# noinspection PyTypeChecker
113
107
m = prepare_pt2e (graph_module , quantizer )
@@ -143,9 +137,7 @@ def test_quantizer_softmax():
143
137
144
138
example_input = (torch .ones (1 , 10 ),)
145
139
quantizer = NeutronQuantizer ()
146
- graph_module = torch .export .export_for_training (
147
- model , example_input , strict = True
148
- ).module ()
140
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
149
141
150
142
# noinspection PyTypeChecker
151
143
m = prepare_pt2e (graph_module , quantizer )
@@ -182,9 +174,7 @@ def test_quantizer_single_maxpool2d():
182
174
183
175
example_input = (torch .ones (1 , 4 , 32 , 32 ),)
184
176
quantizer = NeutronQuantizer ()
185
- graph_module = torch .export .export_for_training (
186
- model , example_input , strict = True
187
- ).module ()
177
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
188
178
189
179
# noinspection PyTypeChecker
190
180
m = prepare_pt2e (graph_module , quantizer )
@@ -206,9 +196,7 @@ def test_quantizer_conv2d_relu():
206
196
207
197
example_input = (torch .ones (1 , 4 , 32 , 32 ),)
208
198
quantizer = NeutronQuantizer ()
209
- graph_module = torch .export .export_for_training (
210
- model , example_input , strict = True
211
- ).module ()
199
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
212
200
213
201
# noinspection PyTypeChecker
214
202
m = prepare_pt2e (graph_module , quantizer )
@@ -231,9 +219,7 @@ def test_quantizer_conv2d_avg_pool2d():
231
219
232
220
example_input = (torch .ones (1 , 4 , 16 , 16 ),)
233
221
quantizer = NeutronQuantizer ()
234
- graph_module = torch .export .export_for_training (
235
- model , example_input , strict = True
236
- ).module ()
222
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
237
223
238
224
# noinspection PyTypeChecker
239
225
m = prepare_pt2e (graph_module , quantizer )
@@ -256,9 +242,7 @@ def test_quantizer_conv2d_permute():
256
242
257
243
example_input = (torch .ones (1 , 4 , 16 , 16 ),)
258
244
quantizer = NeutronQuantizer ()
259
- graph_module = torch .export .export_for_training (
260
- model , example_input , strict = True
261
- ).module ()
245
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
262
246
263
247
# noinspection PyTypeChecker
264
248
m = prepare_pt2e (graph_module , quantizer )
@@ -285,9 +269,7 @@ def test_multiple_shared_spec_ops_in_row():
285
269
286
270
example_input = (torch .ones (1 , 3 , 64 , 64 ),)
287
271
quantizer = NeutronQuantizer ()
288
- graph_module = torch .export .export_for_training (
289
- model , example_input , strict = True
290
- ).module ()
272
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
291
273
292
274
# noinspection PyTypeChecker
293
275
m = prepare_pt2e (graph_module , quantizer )
@@ -321,9 +303,7 @@ def test_quantizers_order_invariance():
321
303
example_input = (torch .ones (1 , 4 , 64 , 64 ),)
322
304
quantizer = NeutronQuantizer ()
323
305
324
- graph_module = torch .export .export_for_training (
325
- model , example_input , strict = True
326
- ).module ()
306
+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
327
307
328
308
m = prepare_pt2e (deepcopy (graph_module ), quantizer )
329
309
m (* example_input )
0 commit comments