Replace export_for_training with torch.export.export #2724

Open · wants to merge 8 commits into base: gh/andrewor14/22/base
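In short: every `torch.export.export_for_training` call (and the older `capture_pre_autograd_graph`) becomes a plain `torch.export.export` call. A minimal sketch of the pattern, assuming a PyTorch version (roughly 2.6 or later) where `torch.export.export` produces the training IR that `export_for_training` used to; the toy module and inputs are illustrative only:

```python
import torch
import torch.nn as nn


class M(nn.Module):
    def forward(self, x):
        return x.relu()


example_inputs = (torch.randn(2, 3),)

# Old spelling, removed throughout this PR:
#   exported = torch.export.export_for_training(M(), example_inputs).module()

# New spelling: export returns an ExportedProgram, and .module() unwraps it
# into a runnable GraphModule that the PT2E quantization APIs expect.
exported = torch.export.export(M(), example_inputs).module()
```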
4 changes: 2 additions & 2 deletions docs/source/tutorials_source/pt2e_quant_ptq.rst
@@ -362,7 +362,7 @@ Here is how you can use ``torch.export`` to export the model:
{0: torch.export.Dim("dim")} if i == 0 else None
for i in range(len(example_inputs))
)
-exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
+exported_model = torch.export.export(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()

# for pytorch 2.5 and before
# dynamic_shape API may vary as well
@@ -501,7 +501,7 @@ Now we can compare the size and model accuracy with baseline model.
# Quantized model size and accuracy
print("Size of model after quantization")
# export again to remove unused weights
-quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
+quantized_model = torch.export.export(quantized_model, example_inputs).module()
print_size_of_model(quantized_model)

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
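The first hunk above also exercises `dynamic_shapes`; a self-contained sketch of the updated export call, using a stand-in `Linear` model (the `dynamic_shapes` comprehension is copied from the tutorial lines above):

```python
import torch

model_to_quantize = torch.nn.Linear(4, 4).eval()
example_inputs = (torch.randn(8, 4),)

# Mark dim 0 of the first positional input (the batch dimension) as dynamic;
# every other input keeps a static shape.
dynamic_shapes = tuple(
    {0: torch.export.Dim("dim")} if i == 0 else None
    for i in range(len(example_inputs))
)
exported_model = torch.export.export(
    model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes
).module()
```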
4 changes: 1 addition & 3 deletions docs/source/tutorials_source/pt2e_quant_qat.rst
@@ -13,7 +13,6 @@ to the post training quantization (PTQ) flow for the most part:
.. code:: python

import torch
-from torch._export import capture_pre_autograd_graph
from torchao.quantization.pt2e.quantize_pt2e import (
prepare_qat_pt2e,
convert_pt2e,
@@ -434,7 +433,6 @@ prepared. For example:

.. code:: python

-from torch._export import capture_pre_autograd_graph
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
@@ -443,7 +441,7 @@ prepared. For example:

example_inputs = (torch.rand(2, 3, 224, 224),)
float_model = resnet18(pretrained=False)
-exported_model = capture_pre_autograd_graph(float_model, example_inputs)
+exported_model = torch.export.export(float_model, example_inputs).module()
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)
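Assembled end to end, the updated QAT capture above looks like this; a sketch that assumes `resnet18` comes from `torchvision.models` and that `executorch` is installed for the XNNPACK quantizer (both imports are supplied by unchanged tutorial lines not shown in the hunk):

```python
import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torchvision.models import resnet18

example_inputs = (torch.rand(2, 3, 224, 224),)
float_model = resnet18(pretrained=False)

# Capture with torch.export.export instead of capture_pre_autograd_graph.
exported_model = torch.export.export(float_model, example_inputs).module()

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)

# ... fine-tune prepared_model with fake quantization here ...

quantized_model = convert_pt2e(prepared_model)
```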
8 changes: 2 additions & 6 deletions docs/source/tutorials_source/pt2e_quant_x86_inductor.rst
@@ -105,7 +105,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t
exported_model = export(
model,
example_inputs
-)
+).module()


Next, we will have the FX Module to be quantized.
@@ -243,12 +243,10 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
.. code:: python

import torch
-from torch._export import capture_pre_autograd_graph
from torchao.quantization.pt2e.quantize_pt2e import (
prepare_qat_pt2e,
convert_pt2e,
)
-from torch.export import export
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer

@@ -264,9 +262,7 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
m = M()

# Step 1. program capture
-# NOTE: this API will be updated to torch.export API in the future, but the captured
-# result shoud mostly stay the same
-exported_model = export(m, example_inputs)
+exported_model = torch.export.export(m, example_inputs).module()
# we get a model with aten ops

# Step 2. quantization-aware training
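The same capture change, applied to the x86 Inductor QAT flow above; a sketch that assumes the default config helper `xiq.get_default_x86_inductor_quantization_config` in torchao mirrors its `torch.ao` counterpart (the `M` module is the toy model from the hunk):

```python
import torch
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(4, 16),)
m = M()

# Step 1. program capture with the unified export API
exported_model = torch.export.export(m, example_inputs).module()

# Step 2. prepare for quantization-aware training
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)

# ... train, then lower ...
converted_model = convert_pt2e(prepared_model)
```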
5 changes: 1 addition & 4 deletions examples/sam2_amg_server/compile_export_utils.py
@@ -118,10 +118,7 @@ def aot_compile(
"max_autotune": True,
"triton.cudagraphs": True,
}
-
-from torch.export import export_for_training
-
-exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
exported.run_decompositions()
output_path = torch._inductor.aoti_compile_and_package(
exported,
5 changes: 1 addition & 4 deletions examples/sam2_vos_example/compile_export_utils.py
@@ -81,10 +81,7 @@ def aot_compile(
"max_autotune": True,
"triton.cudagraphs": True,
}
-
-from torch.export import export_for_training
-
-exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
exported.run_decompositions()
output_path = torch._inductor.aoti_compile_and_package(
exported,
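Both `aot_compile` helpers now share the shape sketched below (the helper name and `output_path` are illustrative). One detail worth noting: `run_decompositions()` returns a new `ExportedProgram` rather than mutating in place, so a sketch that wants the decomposed graph compiled should keep the result:

```python
import torch


def aot_compile_sketch(fn, sample_args, sample_kwargs, output_path):
    # Strict-mode capture with the unified export API.
    exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
    # Keep the decomposed program; run_decompositions() does not mutate.
    exported = exported.run_decompositions()
    # Compile and package the artifact for later loading.
    return torch._inductor.aoti_compile_and_package(
        exported, package_path=output_path
    )
```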
5 changes: 1 addition & 4 deletions test/dtypes/test_uint4.py
@@ -242,10 +242,7 @@ def forward(self, x):

# program capture
m = copy.deepcopy(m_eager)
-m = torch.export.texport_for_training(
-    m,
-    example_inputs,
-).module()
+m = torch.export.export(m, example_inputs).module()

m = prepare_pt2e(m, quantizer)
# Calibrate
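The capture, prepare, calibrate, convert sequence shown here recurs throughout the test files in this PR; a condensed sketch of that PTQ flow (`ptq_flow` is a hypothetical helper, and `quantizer` is any PT2E quantizer, such as the XNNPACK or x86 ones shown above):

```python
import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def ptq_flow(m, example_inputs, quantizer):
    # Program capture with the unified export API.
    m = torch.export.export(m, example_inputs, strict=True).module()
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibrate with representative inputs
    return convert_pt2e(m)  # fold observers into quantized ops
```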
4 changes: 1 addition & 3 deletions test/integration/test_integration.py
@@ -1953,9 +1953,7 @@ def forward(self, x):
# TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
# we can re-enable this after non-functional IR is enabled in export
# model = torch.export.export(model, example_inputs).module()
-model = torch.export.export_for_training(
-    model, example_inputs, strict=True
-).module()
+model = torch.export.export(model, example_inputs, strict=True).module()
after_export = model(x)
self.assertTrue(torch.equal(after_export, ref))
if api is _int8da_int4w_api:
8 changes: 1 addition & 7 deletions test/prototype/inductor/test_int8_sdpa_fusion.py
@@ -157,8 +157,6 @@ def _check_common(
)
@config.patch({"freezing": True})
def _test_sdpa_int8_rewriter(self):
-from torch.export import export_for_training
-
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
@@ -199,11 +197,7 @@ def _test_sdpa_int8_rewriter(self):
quantizer.set_function_type_qconfig(
torch.matmul, quantizer.get_global_quantization_config()
)
-export_model = export_for_training(
-    mod,
-    inputs,
-    strict=True,
-).module()
+export_model = torch.export.export(mod, inputs, strict=True).module()
prepare_model = prepare_pt2e(export_model, quantizer)
prepare_model(*inputs)
convert_model = convert_pt2e(prepare_model)
12 changes: 4 additions & 8 deletions test/quantization/pt2e/test_arm_inductor_quantizer.py
@@ -14,7 +14,6 @@

import torch
import torch.nn as nn
-from torch.export import export_for_training
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
)
@@ -315,10 +314,7 @@ def _test_quantizer(

# program capture
m = copy.deepcopy(m_eager)
-m = export_for_training(
-    m,
-    example_inputs,
-).module()
+m = torch.export.export(m, example_inputs).module()

# QAT Model failed to deepcopy
export_model = m if is_qat else copy.deepcopy(m)
@@ -576,7 +572,7 @@ def _test_linear_unary_helper(
Test pattern of linear with unary post ops (e.g. relu) with ArmInductorQuantizer.
"""
use_bias_list = [True, False]
-# TODO test for inplace add after refactoring of export_for_training
+# TODO test for inplace add after refactoring of export
inplace_list = [False]
if post_op_algo_list is None:
post_op_algo_list = [None]
@@ -716,7 +712,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
Currently, only add as binary post op is supported.
"""
linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
-# TODO test for inplace add after refactoring of export_for_training
+# TODO test for inplace add after refactoring of export
inplace_add_list = [False]
example_inputs = (torch.randn(2, 16),)
quantizer = ArmInductorQuantizer().set_global(
@@ -1078,7 +1074,7 @@ def forward(self, x):
)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
-m = export_for_training(m, example_inputs).module()
+m = torch.export.export(m, example_inputs).module()
m = prepare_pt2e(m, quantizer)
# Use a linear count instead of names because the names might change, but
# the order should be the same.
3 changes: 1 addition & 2 deletions test/quantization/pt2e/test_duplicate_dq.py
@@ -11,7 +11,6 @@
from typing import Any

import torch
-from torch.export import export_for_training
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

@@ -110,7 +109,7 @@ def _test_duplicate_dq(

# program capture
m = copy.deepcopy(m_eager)
-m = export_for_training(m, example_inputs, strict=True).module()
+m = torch.export.export(m, example_inputs, strict=True).module()

m = prepare_pt2e(m, quantizer)
# Calibrate
2 changes: 1 addition & 1 deletion test/quantization/pt2e/test_metadata_porting.py
@@ -107,7 +107,7 @@ def _test_metadata_porting(

# program capture
m = copy.deepcopy(m_eager)
-m = torch.export.export_for_training(m, example_inputs, strict=True).module()
+m = torch.export.export(m, example_inputs, strict=True).module()

m = prepare_pt2e(m, quantizer)
# Calibrate
21 changes: 9 additions & 12 deletions test/quantization/pt2e/test_numeric_debugger.py
@@ -20,11 +20,8 @@
from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase
from torchao.utils import TORCH_VERSION_AT_LEAST_2_8

-if TORCH_VERSION_AT_LEAST_2_8:
-    from torch.export import export_for_training
-
# Increase cache size limit to avoid FailOnRecompileLimitHit error when running multiple tests
-# that use export_for_training, which causes many dynamo recompilations
+# that use torch.export.export, which causes many dynamo recompilations
if TORCH_VERSION_AT_LEAST_2_8:
torch._dynamo.config.cache_size_limit = 128

@@ -37,7 +34,7 @@ class TestNumericDebuggerInfra(PT2ENumericDebuggerTestCase):
def test_simple(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = torch.export.export(m, example_inputs, strict=True)
m = ep.module()
self._assert_each_node_has_from_node_source(m)
from_node_source_map = self._extract_from_node_source(m)
@@ -50,7 +47,7 @@ def test_simple(self):
def test_control_flow(self):
m = TestHelperModules.ControlFlow()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = torch.export.export(m, example_inputs, strict=True)
m = ep.module()

self._assert_each_node_has_from_node_source(m)
@@ -93,13 +90,13 @@ def test_deepcopy_preserve_handle(self):
def test_re_export_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = torch.export.export(m, example_inputs, strict=True)
m = ep.module()

self._assert_each_node_has_from_node_source(m)
from_node_source_map_ref = self._extract_from_node_source(m)

-ep_reexport = export_for_training(m, example_inputs, strict=True)
+ep_reexport = torch.export.export(m, example_inputs, strict=True)
m_reexport = ep_reexport.module()

self._assert_each_node_has_from_node_source(m_reexport)
@@ -110,7 +107,7 @@ def test_re_export_preserve_handle(self):
def test_run_decompositions_same_handle_id(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = torch.export.export(m, example_inputs, strict=True)
m = ep.module()

self._assert_each_node_has_from_node_source(m)
@@ -136,7 +133,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self):

for m in test_models:
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = torch.export.export(m, example_inputs, strict=True)
m = ep.module()

self._assert_each_node_has_from_node_source(m)
@@ -161,7 +158,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self):
def test_prepare_for_propagation_comparison(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = torch.export.export(m, example_inputs, strict=True)
m = ep.module()
m_logger = prepare_for_propagation_comparison(m)
ref = m(*example_inputs)
@@ -177,7 +174,7 @@ def test_prepare_for_propagation_comparison(self):
def test_added_node_gets_unique_id(self) -> None:
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = torch.export.export(m, example_inputs, strict=True)

ref_from_node_source = self._extract_from_node_source(ep.module())
ref_counter = Counter(ref_from_node_source.values())
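The numeric-debugger tests above all follow the same export, re-export, and decompose round trip; a standalone sketch with a stand-in model (provenance checks such as `_assert_each_node_has_from_node_source` are test-case helpers and are omitted here):

```python
import torch

m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
example_inputs = (torch.randn(1, 3, 16, 16),)

# First capture, as in test_simple.
ep = torch.export.export(m, example_inputs, strict=True)
gm = ep.module()

# Re-export the unlifted GraphModule, as in test_re_export_preserve_handle;
# node provenance is expected to survive the round trip.
ep_reexport = torch.export.export(gm, example_inputs, strict=True)

# Decompose to the core ATen opset, as in test_run_decompositions_same_handle_id.
ep_decomposed = ep.run_decompositions()
```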