Commit 9c67384

Arm backend: Replace .export_for_training with .export (#13280)
Signed-off-by: Adrian Lundell <[email protected]>
1 parent 90fd962 commit 9c67384

File tree

5 files changed: +12 -17 lines changed
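
The change is mechanical across all five files: every torch.export.export_for_training(...) call becomes torch.export.export(...) with the same arguments, reflecting that export_for_training is deprecated upstream in favor of plain export. A minimal sketch of the before/after pattern (the Add module and inputs here are illustrative, not taken from the diff):

import torch

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Add().eval()
example_inputs = (torch.randn(4), torch.randn(4))

# Before: torch.export.export_for_training(model, example_inputs, strict=True)
# After:
exported = torch.export.export(model, example_inputs, strict=True)
graph_module = exported.module()  # GraphModule consumed by prepare_pt2e / to_edge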

backends/arm/test/misc/test_extract_io_params_tosa.py

Lines changed: 2 additions & 6 deletions
@@ -60,16 +60,12 @@ def test_roundtrip_extracts_io_params(builder_method, quantizer_cls, partitioner
     operator_config = get_symmetric_quantization_config(is_qat=True)
     quantizer.set_global(operator_config)
 
-    exported = torch.export.export_for_training(
-        mod, copy.deepcopy(example_inputs), strict=True
-    )
+    exported = torch.export.export(mod, copy.deepcopy(example_inputs), strict=True)
     prepared = prepare_pt2e(exported.module(), quantizer)
     _ = prepared(*example_inputs)
 
     converted = convert_pt2e(prepared)
-    final_export = torch.export.export_for_training(
-        converted, example_inputs, strict=True
-    )
+    final_export = torch.export.export(converted, example_inputs, strict=True)
     partitioner = partitioner_cls(compile_spec)
     edge_prog = to_edge_transform_and_lower(final_export, partitioner=[partitioner])
 

backends/cortex_m/test/test_replace_quant_nodes.py

Lines changed: 3 additions & 4 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -16,7 +17,7 @@
     ReplaceQuantNodesPass,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.export import export, export_for_training
+from torch.export import export
 from torch.fx import GraphModule
 from torchao.quantization.pt2e.observer import HistogramObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -125,9 +126,7 @@ def forward(self, x):
     example_inputs = (torch.randn(10, 11, 12),)
 
     # Step 1: Export and quantize the model
-    exported_model = export_for_training(
-        model.eval(), example_inputs, strict=True
-    ).module()
+    exported_model = export(model.eval(), example_inputs, strict=True).module()
     prepared_model = prepare_pt2e(exported_model, AddQuantizer())
     quantized_model = convert_pt2e(prepared_model)
 

docs/source/backends-arm-ethos-u.md

Lines changed: 2 additions & 2 deletions
@@ -50,14 +50,14 @@ compile_spec = ArmCompileSpecBuilder().ethosu_compile_spec(
 ).build()
 
 # Post training quantization
-graph_module = torch.export.export_for_training(mobilenet_v2, example_inputs).module()
+graph_module = torch.export.export(mobilenet_v2, example_inputs).module()
 quantizer = EthosUQuantizer(compile_spec)
 operator_config = get_symmetric_quantization_config(is_per_channel=False)
 quantizer.set_global(operator_config)
 graph_module = prepare_pt2e(graph_module, quantizer)
 graph_module(*example_inputs)
 graph_module = convert_pt2e(graph_module)
-exported_program = torch.export.export_for_training(graph_module, example_inputs)
+exported_program = torch.export.export(graph_module, example_inputs)
 
 # Lower the exported program to the Ethos-U backend and save pte file.
 edge_program_manager = to_edge_transform_and_lower(
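
A detail this doc snippet relies on: convert_pt2e hands back a plain torch.fx.GraphModule, not an ExportedProgram, so the second torch.export.export call is what re-wraps the quantized graph into a form to_edge_transform_and_lower accepts. A small sanity-check sketch, assuming the names from the snippet above are in scope:

# Sketch: the re-export restores the ExportedProgram type for lowering.
assert isinstance(graph_module, torch.fx.GraphModule)
exported_program = torch.export.export(graph_module, example_inputs)
assert isinstance(exported_program, torch.export.ExportedProgram)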

examples/arm/aot_arm_compiler.py

Lines changed: 3 additions & 3 deletions
@@ -710,7 +710,7 @@ def quantize_model(args, model: torch.nn.Module, example_inputs, compile_spec):
         args.evaluate_config,
     )
     # Wrap quantized model back into an exported_program
-    exported_program = torch.export.export_for_training(
+    exported_program = torch.export.export(
         model_int8, example_inputs, strict=args.strict_export
     )
 
@@ -803,9 +803,9 @@ def transform_for_cortex_m_backend(edge):
     )
     model = original_model.eval()
 
-    # export_for_training under the assumption we quantize, the exported form also works
+    # export under the assumption we quantize, the exported form also works
    # in to_edge if we don't quantize
-    exported_program = torch.export.export_for_training(
+    exported_program = torch.export.export(
        model, example_inputs, strict=args.strict_export
    )
    model = exported_program.module()
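
The updated comment states the design choice: a single torch.export.export call up front serves both paths, feeding prepare_pt2e/convert_pt2e when quantizing and going straight to to_edge when not. A minimal sketch of that branch (the export_model helper, quantize flag, and quantizer argument are hypothetical, not part of this file):

import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def export_model(model, example_inputs, quantize=False, quantizer=None):
    # One export path serves both flows.
    exported = torch.export.export(model.eval(), example_inputs)
    if quantize:
        prepared = prepare_pt2e(exported.module(), quantizer)
        prepared(*example_inputs)  # calibration run
        quantized = convert_pt2e(prepared)
        # Wrap the quantized GraphModule back into an ExportedProgram.
        exported = torch.export.export(quantized, example_inputs)
    return exported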

examples/arm/ethos_u_minimal_example.ipynb

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@
 "\n",
 "model = Add()\n",
 "model = model.eval()\n",
-"exported_program = torch.export.export_for_training(model, example_inputs)\n",
+"exported_program = torch.export.export(model, example_inputs)\n",
 "graph_module = exported_program.module()\n",
 "\n",
 "_ = graph_module.print_readable()"
@@ -114,7 +114,7 @@
 "_ = quantized_graph_module.print_readable()\n",
 "\n",
 "# Create a new exported program using the quantized_graph_module\n",
-"quantized_exported_program = torch.export.export_for_training(quantized_graph_module, example_inputs)"
+"quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs)"
 ]
 },
 {
