Skip to content

Commit a03dfc6

Browse files
committed
Replace export_for_training with torch.export.export
**Summary:** Bypasses the following deprecation warning: ``` `torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. Please use `torch.export.export` instead, which is functionally equivalent. ``` Bonus: remove some references to `capture_pre_autograd_graph`, which is even older. **Test Plan:** CI ghstack-source-id: d48c367 Pull Request resolved: #2724
1 parent de3e841 commit a03dfc6

23 files changed

+74
-109
lines changed

docs/source/tutorials_source/pt2e_quant_ptq.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ Here is how you can use ``torch.export`` to export the model:
362362
{0: torch.export.Dim("dim")} if i == 0 else None
363363
for i in range(len(example_inputs))
364364
)
365-
exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
365+
exported_model = torch.export.export(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
366366
367367
# for pytorch 2.5 and before
368368
# dynamic_shape API may vary as well
@@ -501,7 +501,7 @@ Now we can compare the size and model accuracy with baseline model.
501501
# Quantized model size and accuracy
502502
print("Size of model after quantization")
503503
# export again to remove unused weights
504-
quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
504+
quantized_model = torch.export.export(quantized_model, example_inputs).module()
505505
print_size_of_model(quantized_model)
506506
507507
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)

docs/source/tutorials_source/pt2e_quant_qat.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ to the post training quantization (PTQ) flow for the most part:
1313
.. code:: python
1414
1515
import torch
16-
from torch._export import capture_pre_autograd_graph
1716
from torchao.quantization.pt2e.quantize_pt2e import (
1817
prepare_qat_pt2e,
1918
convert_pt2e,
@@ -434,7 +433,6 @@ prepared. For example:
434433

435434
.. code:: python
436435
437-
from torch._export import capture_pre_autograd_graph
438436
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
439437
get_symmetric_quantization_config,
440438
XNNPACKQuantizer,
@@ -443,7 +441,7 @@ prepared. For example:
443441
444442
example_inputs = (torch.rand(2, 3, 224, 224),)
445443
float_model = resnet18(pretrained=False)
446-
exported_model = capture_pre_autograd_graph(float_model, example_inputs)
444+
exported_model = torch.export.export(float_model, example_inputs).module()
447445
quantizer = XNNPACKQuantizer()
448446
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
449447
prepared_model = prepare_qat_pt2e(exported_model, quantizer)

docs/source/tutorials_source/pt2e_quant_x86_inductor.rst

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t
105105
exported_model = export(
106106
model,
107107
example_inputs
108-
)
108+
).module()
109109

110110

111111
Next, we will have the FX Module to be quantized.
@@ -243,12 +243,10 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
243243
.. code:: python
244244
245245
import torch
246-
from torch._export import capture_pre_autograd_graph
247246
from torchao.quantization.pt2e.quantize_pt2e import (
248247
prepare_qat_pt2e,
249248
convert_pt2e,
250249
)
251-
from torch.export import export
252250
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
253251
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer
254252
@@ -264,9 +262,7 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
264262
m = M()
265263
266264
# Step 1. program capture
267-
# NOTE: this API will be updated to torch.export API in the future, but the captured
268-
# result shoud mostly stay the same
269-
exported_model = export(m, example_inputs)
265+
exported_model = torch.export.export(m, example_inputs).module()
270266
# we get a model with aten ops
271267
272268
# Step 2. quantization-aware training

examples/sam2_amg_server/compile_export_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,7 @@ def aot_compile(
118118
"max_autotune": True,
119119
"triton.cudagraphs": True,
120120
}
121-
122-
from torch.export import export_for_training
123-
124-
exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
121+
exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
125122
exported.run_decompositions()
126123
output_path = torch._inductor.aoti_compile_and_package(
127124
exported,

examples/sam2_vos_example/compile_export_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,7 @@ def aot_compile(
8181
"max_autotune": True,
8282
"triton.cudagraphs": True,
8383
}
84-
85-
from torch.export import export_for_training
86-
87-
exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
84+
exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
8885
exported.run_decompositions()
8986
output_path = torch._inductor.aoti_compile_and_package(
9087
exported,

test/dtypes/test_uint4.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,7 @@ def forward(self, x):
242242

243243
# program capture
244244
m = copy.deepcopy(m_eager)
245-
m = torch.export.texport_for_training(
246-
m,
247-
example_inputs,
248-
).module()
245+
m = torch.export.export(m, example_inputs).module()
249246

250247
m = prepare_pt2e(m, quantizer)
251248
# Calibrate

test/integration/test_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,11 +2062,11 @@ def forward(self, x):
20622062
# we can re-enable this after non-functional IR is enabled in export
20632063
# model = torch.export.export(model, example_inputs).module()
20642064
if TORCH_VERSION_AT_LEAST_2_5:
2065-
model = torch.export.export_for_training(
2065+
model = torch.export.export(
20662066
model, example_inputs, strict=True
20672067
).module()
20682068
else:
2069-
model = torch._export.capture_pre_autograd_graph(model, example_inputs)
2069+
raise ValueError("should not be here")
20702070
after_export = model(x)
20712071
self.assertTrue(torch.equal(after_export, ref))
20722072
if api is _int8da_int4w_api:

test/prototype/inductor/test_int8_sdpa_fusion.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@ def _check_common(
157157
)
158158
@config.patch({"freezing": True})
159159
def _test_sdpa_int8_rewriter(self):
160-
from torch.export import export_for_training
161-
162160
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
163161
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
164162
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
@@ -199,11 +197,7 @@ def _test_sdpa_int8_rewriter(self):
199197
quantizer.set_function_type_qconfig(
200198
torch.matmul, quantizer.get_global_quantization_config()
201199
)
202-
export_model = export_for_training(
203-
mod,
204-
inputs,
205-
strict=True,
206-
).module()
200+
export_model = torch.export.export(mod, inputs, strict=True).module()
207201
prepare_model = prepare_pt2e(export_model, quantizer)
208202
prepare_model(*inputs)
209203
convert_model = convert_pt2e(prepare_model)

test/quantization/pt2e/test_arm_inductor_quantizer.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import torch
1616
import torch.nn as nn
17-
from torch.export import export_for_training
1817
from torch.testing._internal.common_quantization import (
1918
NodeSpec as ns,
2019
)
@@ -315,10 +314,7 @@ def _test_quantizer(
315314

316315
# program capture
317316
m = copy.deepcopy(m_eager)
318-
m = export_for_training(
319-
m,
320-
example_inputs,
321-
).module()
317+
m = torch.export.export(m, example_inputs).module()
322318

323319
# QAT Model failed to deepcopy
324320
export_model = m if is_qat else copy.deepcopy(m)
@@ -576,7 +572,7 @@ def _test_linear_unary_helper(
576572
Test pattern of linear with unary post ops (e.g. relu) with ArmInductorQuantizer.
577573
"""
578574
use_bias_list = [True, False]
579-
# TODO test for inplace add after refactoring of export_for_training
575+
# TODO test for inplace add after refactoring of export
580576
inplace_list = [False]
581577
if post_op_algo_list is None:
582578
post_op_algo_list = [None]
@@ -716,7 +712,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
716712
Currently, only add as binary post op is supported.
717713
"""
718714
linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
719-
# TODO test for inplace add after refactoring of export_for_training
715+
# TODO test for inplace add after refactoring of export
720716
inplace_add_list = [False]
721717
example_inputs = (torch.randn(2, 16),)
722718
quantizer = ArmInductorQuantizer().set_global(
@@ -1078,7 +1074,7 @@ def forward(self, x):
10781074
)
10791075
example_inputs = (torch.randn(2, 2),)
10801076
m = M().eval()
1081-
m = export_for_training(m, example_inputs).module()
1077+
m = torch.export.export(m, example_inputs).module()
10821078
m = prepare_pt2e(m, quantizer)
10831079
# Use a linear count instead of names because the names might change, but
10841080
# the order should be the same.

test/quantization/pt2e/test_duplicate_dq.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from typing import Any
1212

1313
import torch
14-
from torch.export import export_for_training
1514
from torch.testing._internal.common_quantization import QuantizationTestCase
1615
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
1716

@@ -110,7 +109,7 @@ def _test_duplicate_dq(
110109

111110
# program capture
112111
m = copy.deepcopy(m_eager)
113-
m = export_for_training(m, example_inputs, strict=True).module()
112+
m = torch.export.export(m, example_inputs, strict=True).module()
114113

115114
m = prepare_pt2e(m, quantizer)
116115
# Calibrate

0 commit comments

Comments
 (0)