Skip to content

Commit 127d625

Browse files
tugsbayasgalanfacebook-github-bot
authored andcommitted
Replace export_for_training with export (#14073)
Summary: Pull Request resolved: #14073 export_for_training is deprecated, so replace it with export. Differential Revision: D81936329
1 parent a90e907 commit 127d625

File tree

37 files changed

+107
-177
lines changed

37 files changed

+107
-177
lines changed

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616

1717
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
18-
from torch.export import export_for_training
18+
from torch.export import export
1919
from torchao.quantization.pt2e.quantize_pt2e import (
2020
convert_pt2e,
2121
prepare_pt2e,
@@ -32,9 +32,7 @@ def quantize_and_compare(
3232
) -> None:
3333
assert quantization_type in {"PTQ", "QAT"}
3434

35-
pre_autograd_aten_dialect = export_for_training(
36-
model, example_inputs, strict=True
37-
).module()
35+
pre_autograd_aten_dialect = export(model, example_inputs, strict=True).module()
3836

3937
quantization_config = LinearQuantizerConfig.from_dict(
4038
{

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def lower_module_and_test_output(
206206

207207
expected_output = model(*sample_inputs)
208208

209-
model = torch.export.export_for_training(
209+
model = torch.export.export(
210210
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
211211
).module()
212212

backends/cortex_m/test/test_quantize_op_fusion_pass.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
get_node_args,
2424
)
2525
from executorch.exir.dialects._ops import ops as exir_ops
26-
from torch.export import export, export_for_training
26+
from torch.export import export, export
2727
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
2828

2929

@@ -42,9 +42,7 @@ def _prepare_quantized_model(self, model_class):
4242
model = model_class()
4343

4444
# Export and quantize
45-
exported_model = export_for_training(
46-
model.eval(), self.example_inputs, strict=True
47-
).module()
45+
exported_model = export(model.eval(), self.example_inputs, strict=True).module()
4846
prepared_model = prepare_pt2e(exported_model, AddQuantizer())
4947
quantized_model = convert_pt2e(prepared_model)
5048

@@ -242,9 +240,7 @@ def forward(self, x, y):
242240
inputs = (torch.randn(shape), torch.randn(shape))
243241

244242
model = SingleAddModel()
245-
exported_model = export_for_training(
246-
model.eval(), inputs, strict=True
247-
).module()
243+
exported_model = export(model.eval(), inputs, strict=True).module()
248244
prepared_model = prepare_pt2e(exported_model, AddQuantizer())
249245
quantized_model = convert_pt2e(prepared_model)
250246

backends/example/test_example_delegate.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ def get_example_inputs():
4646
)
4747

4848
m = model.eval()
49-
m = torch.export.export_for_training(
50-
m, copy.deepcopy(example_inputs), strict=True
51-
).module()
49+
m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
5250
# print("original model:", m)
5351
quantizer = ExampleQuantizer()
5452
# quantizer = XNNPACKQuantizer()
@@ -84,9 +82,7 @@ def test_delegate_mobilenet_v2(self):
8482
)
8583

8684
m = model.eval()
87-
m = torch.export.export_for_training(
88-
m, copy.deepcopy(example_inputs), strict=True
89-
).module()
85+
m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
9086
quantizer = ExampleQuantizer()
9187

9288
m = prepare_pt2e(m, quantizer)

backends/mediatek/quantizer/annotator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch._ops import OpOverload
1111
from torch._subclasses import FakeTensor
1212

13-
from torch.export import export_for_training
13+
from torch.export import export
1414
from torch.fx import Graph, Node
1515
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
1616
SubgraphMatcherWithNameNodeMap,
@@ -158,9 +158,7 @@ def forward(self, x):
158158
return norm, {}
159159

160160
for pattern_cls in (ExecuTorchPattern, MTKPattern):
161-
pattern_gm = export_for_training(
162-
pattern_cls(), (torch.randn(3, 3),), strict=True
163-
).module()
161+
pattern_gm = export(pattern_cls(), (torch.randn(3, 3),), strict=True).module()
164162
matcher = SubgraphMatcherWithNameNodeMap(
165163
pattern_gm, ignore_literals=True, remove_overlapping_matches=False
166164
)

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def get_prepared_qat_module(
576576
quant_dtype: QuantDtype = QuantDtype.use_8a8w,
577577
submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
578578
) -> torch.fx.GraphModule:
579-
m = torch.export.export_for_training(module, inputs, strict=True).module()
579+
m = torch.export.export(module, inputs, strict=True).module()
580580

581581
quantizer = make_quantizer(
582582
quant_dtype=quant_dtype,

backends/test/harness/stages/quantize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
DuplicateDynamicQuantChainPass,
88
)
99

10-
from torch.export import export_for_training
10+
from torch.export import export
1111

1212
from torchao.quantization.pt2e.quantize_pt2e import (
1313
convert_pt2e,
@@ -47,7 +47,7 @@ def run(
4747
assert inputs is not None
4848
if self.is_qat:
4949
artifact.train()
50-
captured_graph = export_for_training(artifact, inputs, strict=True).module()
50+
captured_graph = export(artifact, inputs, strict=True).module()
5151

5252
assert isinstance(captured_graph, torch.fx.GraphModule)
5353

backends/transforms/test/test_duplicate_dynamic_quant_chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _test_duplicate_chain(
5858

5959
# program capture
6060
m = copy.deepcopy(m_eager)
61-
m = torch.export.export_for_training(m, example_inputs, strict=True).module()
61+
m = torch.export.export(m, example_inputs, strict=True).module()
6262

6363
m = prepare_pt2e(m, quantizer)
6464
# Calibrate

backends/vulkan/test/test_vulkan_passes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def quantize_and_lower_module(
5858
_check_ir_validity=False,
5959
)
6060

61-
program = torch.export.export_for_training(
61+
program = torch.export.export(
6262
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
6363
).module()
6464

@@ -95,7 +95,6 @@ def op_node_count(graph_module: torch.fx.GraphModule, canonical_op_name: str) ->
9595

9696

9797
class TestVulkanPasses(unittest.TestCase):
98-
9998
def test_fuse_int8pack_mm(self):
10099
K = 256
101100
N = 256
@@ -184,7 +183,7 @@ def test_fuse_linear_qta8a_qga4w(self):
184183
_check_ir_validity=False,
185184
)
186185

187-
program = torch.export.export_for_training(
186+
program = torch.export.export(
188187
quantized_model, sample_inputs, strict=True
189188
).module()
190189

backends/vulkan/test/utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
_load_for_executorch_from_buffer,
3636
)
3737
from executorch.extension.pytree import tree_flatten
38-
from torch.export import export, export_for_training
38+
from torch.export import export, export
3939

4040
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
4141

@@ -53,7 +53,7 @@ def get_exported_graph(
5353
dynamic_shapes=None,
5454
qmode=QuantizationMode.NONE,
5555
) -> torch.fx.GraphModule:
56-
export_training_graph = export_for_training(
56+
export_training_graph = export(
5757
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
5858
).module()
5959

@@ -590,9 +590,7 @@ def op_ablation_test( # noqa: C901
590590
logger.info("Starting fast binary search operator ablation test...")
591591

592592
# Step 1: Export model to get edge_program and extract operators
593-
export_training_graph = export_for_training(
594-
model, sample_inputs, strict=True
595-
).module()
593+
export_training_graph = export(model, sample_inputs, strict=True).module()
596594
program = export(
597595
export_training_graph,
598596
sample_inputs,

0 commit comments

Comments
 (0)