Replace export_for_training with torch.export.export #2724

Merged: andrewor14 merged 28 commits into main from gh/andrewor14/22/head on Aug 13, 2025.
Commits (28, all by andrewor14):

e1d7de3  Deprecate old TORCH_VERSION variables (Aug 8, 2025)
922fc3e  Update on "Deprecate old TORCH_VERSION variables" (Aug 8, 2025)
fc7dffe  Drop support for PyTorch 2.5 and before (Aug 8, 2025)
42f9081  Remove old `change_linear_weights_to_*` APIs (Aug 8, 2025)
83fb739  Update base for Update on "Remove old `change_linear_weights_to_*` APIs" (Aug 8, 2025)
afedb9f  Update on "Remove old `change_linear_weights_to_*` APIs" (Aug 8, 2025)
4697b22  Update base for Update on "Remove old `change_linear_weights_to_*` APIs" (Aug 8, 2025)
da64318  Update on "Remove old `change_linear_weights_to_*` APIs" (Aug 8, 2025)
ac6c78f  Update base for Update on "Remove old `change_linear_weights_to_*` APIs" (Aug 8, 2025)
4d93ac7  Update on "Remove old `change_linear_weights_to_*` APIs" (Aug 8, 2025)
d6c4715  Update base for Update on "Remove old `change_linear_weights_to_*` APIs" (Aug 8, 2025)
d4762be  Update on "Remove old `change_linear_weights_to_*` APIs" (Aug 8, 2025)
2bc14bd  Replace `export_for_training` with `torch.export.export` (Aug 8, 2025)
e87d7b2  Update base for Update on "Replace `export_for_training` with `torch.export.export`" (Aug 8, 2025)
8a8843f  Update on "Replace `export_for_training` with `torch.export.export`" (Aug 8, 2025)
827d81b  Update base for Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
10066ca  Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
a15f8fa  Update base for Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
b3cf7b4  Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
a5f8040  Update base for Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
7792da8  Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
8caf0a5  Update base for Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
c2ffa16  Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
eda2df2  Update base for Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
1f0de23  Update on "Replace `export_for_training` with `torch.export.export`" (Aug 12, 2025)
69c6c34  Update base for Update on "Replace `export_for_training` with `torch.export.export`" (Aug 13, 2025)
eed1c7c  Update on "Replace `export_for_training` with `torch.export.export`" (Aug 13, 2025)
d687491  Merge branch 'main' into gh/andrewor14/22/head (Aug 13, 2025)
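At each call site the change itself is one line. A minimal sketch of the before/after pattern on a hypothetical toy module (not code from this PR), assuming a PyTorch build where `torch.export.export` captures the same training-friendly graph that `export_for_training` produced:

# Sketch only: hypothetical module, illustrating the one-line migration.
import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)

model = ToyModel().eval()
example_inputs = (torch.randn(1, 16),)

# Before (deprecated entry point):
#   exported = torch.export.export_for_training(model, example_inputs).module()
# After:
exported = torch.export.export(model, example_inputs).module()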
24 changes: 12 additions & 12 deletions .github/workflows/regression_test.yml
@@ -59,12 +59,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - name: CUDA 2.5.1
-            runs-on: linux.g5.12xlarge.nvidia.gpu
-            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
-            gpu-arch-type: "cuda"
-            gpu-arch-version: "12.6"
-            dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/"
           - name: CUDA 2.6
             runs-on: linux.g5.12xlarge.nvidia.gpu
             torch-spec: 'torch==2.6.0'
@@ -77,13 +71,13 @@
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.6"
             dev-requirements-overrides: ""
+          - name: CUDA 2.8
+            runs-on: linux.g5.12xlarge.nvidia.gpu
+            torch-spec: 'torch==2.8.0'
+            gpu-arch-type: "cuda"
+            gpu-arch-version: "12.6"
+            dev-requirements-overrides: ""
 
-          - name: CPU 2.5.1
-            runs-on: linux.4xlarge
-            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu'
-            gpu-arch-type: "cpu"
-            gpu-arch-version: ""
-            dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/"
           - name: CPU 2.6
             runs-on: linux.4xlarge
             torch-spec: 'torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu'
@@ -96,6 +90,12 @@ jobs:
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
             dev-requirements-overrides: ""
+          - name: CPU 2.8
+            runs-on: linux.4xlarge
+            torch-spec: 'torch==2.8.0 --index-url https://download.pytorch.org/whl/cpu'
+            gpu-arch-type: "cpu"
+            gpu-arch-version: ""
+            dev-requirements-overrides: ""
 
   uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
   with:
51 changes: 15 additions & 36 deletions benchmarks/benchmark_aq.py
@@ -20,46 +20,26 @@
     Int4WeightOnlyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
 )
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    TORCH_VERSION_AT_LEAST_2_5,
-    unwrap_tensor_subclass,
-)
 
 
 def _int8wo_api(mod, **kwargs):
-    if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False)
-        if not TORCH_VERSION_AT_LEAST_2_5:
-            unwrap_tensor_subclass(mod)
-    else:
-        change_linear_weights_to_int8_woqtensors(mod, **kwargs)
+    quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False)
 
 
 def _int8da_int8w_api(mod, **kwargs):
-    if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(
-            mod,
-            int8_dynamic_activation_int8_weight(**kwargs),
-            set_inductor_config=False,
-        )
-        if not TORCH_VERSION_AT_LEAST_2_5:
-            unwrap_tensor_subclass(mod)
-    else:
-        change_linear_weights_to_int8_dqtensors(mod, **kwargs)
+    quantize_(
+        mod,
+        int8_dynamic_activation_int8_weight(**kwargs),
+        set_inductor_config=False,
+    )
 
 
 def _int4wo_api(mod, **kwargs):
-    if TORCH_VERSION_AT_LEAST_2_4:
-        kwargs_copy = kwargs.copy()
-        if "groupsize" in kwargs_copy:
-            kwargs_copy["group_size"] = kwargs_copy["groupsize"]
-            del kwargs_copy["groupsize"]
-        quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False)
-        if not TORCH_VERSION_AT_LEAST_2_5:
-            unwrap_tensor_subclass(mod)
-    else:
-        change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+    kwargs_copy = kwargs.copy()
+    if "groupsize" in kwargs_copy:
+        kwargs_copy["group_size"] = kwargs_copy["groupsize"]
+        del kwargs_copy["groupsize"]
+    quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False)
 
 
 class ToyLinearModel(torch.nn.Module):
@@ -95,11 +75,13 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
     """
     from torchao.quantization.quant_api import (
         _get_subclass_inserter,
-        _in_features_greater_than_16,
         _is_linear,
     )
     from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
 
+    def _in_features_greater_than_16(mod, *args):
+        return hasattr(mod, "in_features") and mod.in_features > 16
+
     if filter_fn is None:
         filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
             *args
@@ -195,21 +177,19 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     )
 
 
-if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
+if __name__ == "__main__" and torch.cuda.is_available():
     all_shapes = [
         (20, 2048, 2048),
     ]
 
     print("_int8da_int8w_api")
-    from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
 
     for M, N, K in all_shapes:
         _bench_quantized_tensor_subclass_perf(
             _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K
         )
 
     print("_int8wo_api")
-    from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
 
     for M, N, K in all_shapes:
         _bench_quantized_tensor_subclass_perf(
@@ -218,7 +198,6 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
 
     print("_int4wo_api")
     kwargs = {"groupsize": 32}
-    from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
 
     for M, N, K in all_shapes:
         _bench_quantized_tensor_subclass_perf(
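With the version gates gone, each helper above reduces to a single `quantize_` call. A minimal usage sketch (the Sequential model and shapes are illustrative, not the benchmark's own setup):

# Sketch: int8 weight-only quantization of a toy linear stack.
import torch
from torchao.quantization.quant_api import int8_weight_only, quantize_

mod = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
quantize_(mod, int8_weight_only(), set_inductor_config=False)

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = mod(x)  # linear now runs with int8 weight-only quantized weights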
4 changes: 0 additions & 4 deletions docs/source/pretraining.rst
@@ -161,10 +161,6 @@ Below is a code snippet showing how to use it:
     from torchao.float8.float8_linear_utils import convert_to_float8_training
     from torchao.float8.float8_linear import Float8Linear
     from torchao.float8 import convert_to_float8_training
-    from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
-
-    if not TORCH_VERSION_AT_LEAST_2_5:
-        raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
 
     # create model and sample input
     m = nn.Sequential(
6 changes: 0 additions & 6 deletions docs/source/quick_start.rst
@@ -95,16 +95,10 @@ it is also much faster!
 .. code:: py
 
     from torchao.utils import (
-        TORCH_VERSION_AT_LEAST_2_5,
         benchmark_model,
-        unwrap_tensor_subclass,
     )
 
-    # Temporary workaround for tensor subclass + torch.compile
-    # Only needed for torch version < 2.5
-    if not TORCH_VERSION_AT_LEAST_2_5:
-        unwrap_tensor_subclass(model)
 
     num_runs = 100
     torch._dynamo.reset()
     example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),)
4 changes: 2 additions & 2 deletions docs/source/tutorials_source/pt2e_quant_ptq.rst
@@ -362,7 +362,7 @@ Here is how you can use ``torch.export`` to export the model:
         {0: torch.export.Dim("dim")} if i == 0 else None
         for i in range(len(example_inputs))
     )
-    exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
+    exported_model = torch.export.export(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
 
     # for pytorch 2.5 and before
     # dynamic_shape API may vary as well
@@ -501,7 +501,7 @@ Now we can compare the size and model accuracy with baseline model.
     # Quantized model size and accuracy
     print("Size of model after quantization")
     # export again to remove unused weights
-    quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
+    quantized_model = torch.export.export(quantized_model, example_inputs).module()
     print_size_of_model(quantized_model)
 
     top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
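For reference, the updated export step runs end to end like this minimal sketch (the resnet18 stand-in is an assumption; the tutorial's `model_to_quantize` may differ):

# Sketch: torch.export.export with a dynamic batch dimension,
# mirroring the tutorial snippet above (model choice is illustrative).
import torch
from torchvision.models import resnet18

model_to_quantize = resnet18(weights=None).eval()
example_inputs = (torch.rand(2, 3, 224, 224),)

# Mark dim 0 of the first input as dynamic, as in the tutorial snippet.
dynamic_shapes = tuple(
    {0: torch.export.Dim("dim")} if i == 0 else None
    for i in range(len(example_inputs))
)
exported_model = torch.export.export(
    model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes
).module()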
4 changes: 1 addition & 3 deletions docs/source/tutorials_source/pt2e_quant_qat.rst
@@ -13,7 +13,6 @@ to the post training quantization (PTQ) flow for the most part:
 .. code:: python
 
     import torch
-    from torch._export import capture_pre_autograd_graph
     from torchao.quantization.pt2e.quantize_pt2e import (
         prepare_qat_pt2e,
         convert_pt2e,
@@ -434,7 +433,6 @@ prepared. For example:
 
 .. code:: python
 
-    from torch._export import capture_pre_autograd_graph
     from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
         get_symmetric_quantization_config,
         XNNPACKQuantizer,
@@ -443,7 +441,7 @@ prepared. For example:
 
     example_inputs = (torch.rand(2, 3, 224, 224),)
     float_model = resnet18(pretrained=False)
-    exported_model = capture_pre_autograd_graph(float_model, example_inputs)
+    exported_model = torch.export.export(float_model, example_inputs).module()
     quantizer = XNNPACKQuantizer()
     quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
     prepared_model = prepare_qat_pt2e(exported_model, quantizer)
8 changes: 2 additions & 6 deletions docs/source/tutorials_source/pt2e_quant_x86_inductor.rst
@@ -105,7 +105,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t
     exported_model = export(
         model,
         example_inputs
-    )
+    ).module()
 
 
 Next, we will have the FX Module to be quantized.
@@ -243,12 +243,10 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
 
 .. code:: python
 
     import torch
-    from torch._export import capture_pre_autograd_graph
     from torchao.quantization.pt2e.quantize_pt2e import (
         prepare_qat_pt2e,
         convert_pt2e,
     )
-    from torch.export import export
     import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
     from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer
@@ -264,9 +262,7 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
     m = M()
 
     # Step 1. program capture
-    # NOTE: this API will be updated to torch.export API in the future, but the captured
-    # result shoud mostly stay the same
-    exported_model = export(m, example_inputs)
+    exported_model = torch.export.export(m, example_inputs).module()
     # we get a model with aten ops
 
     # Step 2. quantization-aware training
5 changes: 1 addition & 4 deletions examples/sam2_amg_server/compile_export_utils.py
@@ -118,10 +118,7 @@ def aot_compile(
         "max_autotune": True,
         "triton.cudagraphs": True,
     }
-
-    from torch.export import export_for_training
-
-    exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+    exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
     exported.run_decompositions()
     output_path = torch._inductor.aoti_compile_and_package(
         exported,
5 changes: 1 addition & 4 deletions examples/sam2_vos_example/compile_export_utils.py
@@ -81,10 +81,7 @@ def aot_compile(
         "max_autotune": True,
         "triton.cudagraphs": True,
     }
-
-    from torch.export import export_for_training
-
-    exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+    exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
     exported.run_decompositions()
     output_path = torch._inductor.aoti_compile_and_package(
         exported,
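Both sam2 helpers now share the same export-then-package flow. A condensed, self-contained sketch of that pattern (the function below is a stand-in for the surrounding `aot_compile`; names are placeholders):

# Sketch of the export + AOTI packaging pattern used by both helpers.
import torch

def aot_compile_sketch(fn, sample_args, sample_kwargs, output_path):
    inductor_configs = {
        "max_autotune": True,
        "triton.cudagraphs": True,
    }
    exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
    # run_decompositions() returns a new ExportedProgram; reassigned here
    # for clarity.
    exported = exported.run_decompositions()
    return torch._inductor.aoti_compile_and_package(
        exported,
        package_path=output_path,
        inductor_configs=inductor_configs,
    )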
11 changes: 1 addition & 10 deletions scripts/quick_start.py
@@ -8,11 +8,7 @@
 import torch
 
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
-    benchmark_model,
-    unwrap_tensor_subclass,
-)
+from torchao.utils import benchmark_model
 
 # ================
 # | Set up model |
@@ -50,11 +46,6 @@ def forward(self, x):
 # | Benchmark |
 # =============
 
-# Temporary workaround for tensor subclass + torch.compile
-# Only needed for torch version < 2.5
-if not TORCH_VERSION_AT_LEAST_2_5:
-    unwrap_tensor_subclass(model)
-
 num_runs = 100
 torch._dynamo.reset()
 example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),)
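What remains of the benchmark section is the standard compile-and-measure loop. A self-contained sketch (toy model; `benchmark_model` returning milliseconds is an assumption based on the quick-start docs):

# Sketch: compile a toy model and time it with torchao's benchmark_model.
import torch
from torchao.utils import benchmark_model

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
num_runs = 100
torch._dynamo.reset()
example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),)

model = torch.compile(model, mode="max-autotune")
benchmark_model(model, 1, example_inputs)  # warm up / trigger compilation
elapsed = benchmark_model(model, num_runs, example_inputs)
print(f"mean time per run: {elapsed:.3f} ms")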
5 changes: 1 addition & 4 deletions test/core/test_config.py
@@ -39,7 +39,6 @@
     UIntXWeightOnlyConfig,
 )
 from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
 
 # Define test configurations as fixtures
 configs = [
@@ -85,11 +84,9 @@
     ),
     AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING),
     AWQConfig(Int4WeightOnlyConfig(group_size=128), step="prepare_for_loading"),
+    FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256]),
 ]
 
-if TORCH_VERSION_AT_LEAST_2_6:
-    configs += [FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])]
-
 
 # Create ids for better test naming
 def get_config_ids(configs):
7 changes: 1 addition & 6 deletions test/dtypes/test_affine_quantized.py
@@ -41,7 +41,6 @@
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.testing.utils import skip_if_no_cuda, skip_if_no_gemlite, skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
     check_cpu_version,
     check_xpu_version,
     is_fbcode,
@@ -151,11 +150,7 @@ def test_weights_only(self):
         with tempfile.NamedTemporaryFile() as f:
             torch.save(ql.state_dict(), f)
             f.seek(0)
-            # `weights_only=True` is enabled for torch 2.5+
-            if TORCH_VERSION_AT_LEAST_2_5:
-                _ = torch.load(f, weights_only=True)
-            else:
-                _ = torch.load(f, weights_only=False)
+            _ = torch.load(f, weights_only=True)
 
     @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
9 changes: 0 additions & 9 deletions test/dtypes/test_affine_quantized_float.py
@@ -3,15 +3,6 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
-
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
-)
-
-if not TORCH_VERSION_AT_LEAST_2_5:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
-
 import copy
 import io
 import random
5 changes: 0 additions & 5 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -24,7 +24,6 @@
 )
 from torchao.quantization.observer import PerRow, PerTensor
 from torchao.quantization.quant_api import quantize_
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
 
 if common_utils.SEED is None:
     common_utils.SEED = 1234
@@ -127,10 +126,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         dn_dist(up_dist(input_dtensor))
 
-        if not TORCH_VERSION_AT_LEAST_2_6:
-            # Need torch 2.6 to support compiled tensor parallelism
-            return
-
         up_compiled = torch.compile(up_dist)
         y_up = up_compiled(input_dtensor)
         dn_compiled = torch.compile(dn_dist)
6 changes: 1 addition & 5 deletions test/dtypes/test_floatx.py
@@ -33,7 +33,7 @@
     quantize_,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
+from torchao.utils import is_fbcode
 
 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
 _Floatx_DTYPES = [(3, 2), (2, 2)]
@@ -107,10 +107,6 @@ def test_to_copy_device(self, ebits, mbits):
         assert floatx_tensor_impl.device.type == "cpu"
 
     @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
-    @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_5,
-        reason="quantization only works with torch.compile for 2.5+",
-    )
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])