@@ -23,9 +23,7 @@ def test_quantizer_conv2d():
2323
2424 example_input = (torch .ones (1 , 4 , 32 , 32 ),)
2525 quantizer = NeutronQuantizer ()
26- graph_module = torch .export .export_for_training (
27- model , example_input , strict = True
28- ).module ()
26+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
2927
3028 # noinspection PyTypeChecker
3129 m = prepare_pt2e (graph_module , quantizer )
@@ -64,9 +62,7 @@ def test_quantizer_linear():
6462
6563 example_input = (torch .ones (10 , 32 ),)
6664 quantizer = NeutronQuantizer ()
67- graph_module = torch .export .export_for_training (
68- model , example_input , strict = True
69- ).module ()
65+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
7066
7167 # noinspection PyTypeChecker
7268 m = prepare_pt2e (graph_module , quantizer )
@@ -105,9 +101,7 @@ def test_quantizer_maxpool2d():
105101
106102 example_input = (torch .ones (1 , 8 , 32 , 32 ),)
107103 quantizer = NeutronQuantizer ()
108- graph_module = torch .export .export_for_training (
109- model , example_input , strict = True
110- ).module ()
104+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
111105
112106 # noinspection PyTypeChecker
113107 m = prepare_pt2e (graph_module , quantizer )
@@ -143,9 +137,7 @@ def test_quantizer_softmax():
143137
144138 example_input = (torch .ones (1 , 10 ),)
145139 quantizer = NeutronQuantizer ()
146- graph_module = torch .export .export_for_training (
147- model , example_input , strict = True
148- ).module ()
140+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
149141
150142 # noinspection PyTypeChecker
151143 m = prepare_pt2e (graph_module , quantizer )
@@ -182,9 +174,7 @@ def test_quantizer_single_maxpool2d():
182174
183175 example_input = (torch .ones (1 , 4 , 32 , 32 ),)
184176 quantizer = NeutronQuantizer ()
185- graph_module = torch .export .export_for_training (
186- model , example_input , strict = True
187- ).module ()
177+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
188178
189179 # noinspection PyTypeChecker
190180 m = prepare_pt2e (graph_module , quantizer )
@@ -206,9 +196,7 @@ def test_quantizer_conv2d_relu():
206196
207197 example_input = (torch .ones (1 , 4 , 32 , 32 ),)
208198 quantizer = NeutronQuantizer ()
209- graph_module = torch .export .export_for_training (
210- model , example_input , strict = True
211- ).module ()
199+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
212200
213201 # noinspection PyTypeChecker
214202 m = prepare_pt2e (graph_module , quantizer )
@@ -231,9 +219,7 @@ def test_quantizer_conv2d_avg_pool2d():
231219
232220 example_input = (torch .ones (1 , 4 , 16 , 16 ),)
233221 quantizer = NeutronQuantizer ()
234- graph_module = torch .export .export_for_training (
235- model , example_input , strict = True
236- ).module ()
222+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
237223
238224 # noinspection PyTypeChecker
239225 m = prepare_pt2e (graph_module , quantizer )
@@ -256,9 +242,7 @@ def test_quantizer_conv2d_permute():
256242
257243 example_input = (torch .ones (1 , 4 , 16 , 16 ),)
258244 quantizer = NeutronQuantizer ()
259- graph_module = torch .export .export_for_training (
260- model , example_input , strict = True
261- ).module ()
245+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
262246
263247 # noinspection PyTypeChecker
264248 m = prepare_pt2e (graph_module , quantizer )
@@ -285,9 +269,7 @@ def test_multiple_shared_spec_ops_in_row():
285269
286270 example_input = (torch .ones (1 , 3 , 64 , 64 ),)
287271 quantizer = NeutronQuantizer ()
288- graph_module = torch .export .export_for_training (
289- model , example_input , strict = True
290- ).module ()
272+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
291273
292274 # noinspection PyTypeChecker
293275 m = prepare_pt2e (graph_module , quantizer )
@@ -321,9 +303,7 @@ def test_quantizers_order_invariance():
321303 example_input = (torch .ones (1 , 4 , 64 , 64 ),)
322304 quantizer = NeutronQuantizer ()
323305
324- graph_module = torch .export .export_for_training (
325- model , example_input , strict = True
326- ).module ()
306+ graph_module = torch .export .export (model , example_input , strict = True ).module ()
327307
328308 m = prepare_pt2e (deepcopy (graph_module ), quantizer )
329309 m (* example_input )
0 commit comments