Skip to content

Commit 4b831ff

Browse files
tugsbayasgalanStrycekSimon
authored andcommitted
Replace export_for_training with export
Differential Revision: D81936329 Pull Request resolved: pytorch#14073
1 parent a8e932b commit 4b831ff

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+123
-184
lines changed

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616

1717
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
18-
from torch.export import export_for_training
18+
from torch.export import export
1919
from torchao.quantization.pt2e.quantize_pt2e import (
2020
convert_pt2e,
2121
prepare_pt2e,
@@ -32,9 +32,7 @@ def quantize_and_compare(
3232
) -> None:
3333
assert quantization_type in {"PTQ", "QAT"}
3434

35-
pre_autograd_aten_dialect = export_for_training(
36-
model, example_inputs, strict=True
37-
).module()
35+
pre_autograd_aten_dialect = export(model, example_inputs, strict=True).module()
3836

3937
quantization_config = LinearQuantizerConfig.from_dict(
4038
{

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def lower_module_and_test_output(
206206

207207
expected_output = model(*sample_inputs)
208208

209-
model = torch.export.export_for_training(
209+
model = torch.export.export(
210210
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
211211
).module()
212212

backends/cortex_m/test/test_quantize_op_fusion_pass.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
get_node_args,
2424
)
2525
from executorch.exir.dialects._ops import ops as exir_ops
26-
from torch.export import export, export_for_training
26+
from torch.export import export
2727
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
2828

2929

@@ -42,9 +42,7 @@ def _prepare_quantized_model(self, model_class):
4242
model = model_class()
4343

4444
# Export and quantize
45-
exported_model = export_for_training(
46-
model.eval(), self.example_inputs, strict=True
47-
).module()
45+
exported_model = export(model.eval(), self.example_inputs, strict=True).module()
4846
prepared_model = prepare_pt2e(exported_model, AddQuantizer())
4947
quantized_model = convert_pt2e(prepared_model)
5048

@@ -242,9 +240,7 @@ def forward(self, x, y):
242240
inputs = (torch.randn(shape), torch.randn(shape))
243241

244242
model = SingleAddModel()
245-
exported_model = export_for_training(
246-
model.eval(), inputs, strict=True
247-
).module()
243+
exported_model = export(model.eval(), inputs, strict=True).module()
248244
prepared_model = prepare_pt2e(exported_model, AddQuantizer())
249245
quantized_model = convert_pt2e(prepared_model)
250246

backends/example/test_example_delegate.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ def get_example_inputs():
4646
)
4747

4848
m = model.eval()
49-
m = torch.export.export_for_training(
50-
m, copy.deepcopy(example_inputs), strict=True
51-
).module()
49+
m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
5250
# print("original model:", m)
5351
quantizer = ExampleQuantizer()
5452
# quantizer = XNNPACKQuantizer()
@@ -84,9 +82,7 @@ def test_delegate_mobilenet_v2(self):
8482
)
8583

8684
m = model.eval()
87-
m = torch.export.export_for_training(
88-
m, copy.deepcopy(example_inputs), strict=True
89-
).module()
85+
m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
9086
quantizer = ExampleQuantizer()
9187

9288
m = prepare_pt2e(m, quantizer)

backends/mediatek/quantizer/annotator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch._ops import OpOverload
1111
from torch._subclasses import FakeTensor
1212

13-
from torch.export import export_for_training
13+
from torch.export import export
1414
from torch.fx import Graph, Node
1515
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
1616
SubgraphMatcherWithNameNodeMap,
@@ -158,9 +158,7 @@ def forward(self, x):
158158
return norm, {}
159159

160160
for pattern_cls in (ExecuTorchPattern, MTKPattern):
161-
pattern_gm = export_for_training(
162-
pattern_cls(), (torch.randn(3, 3),), strict=True
163-
).module()
161+
pattern_gm = export(pattern_cls(), (torch.randn(3, 3),), strict=True).module()
164162
matcher = SubgraphMatcherWithNameNodeMap(
165163
pattern_gm, ignore_literals=True, remove_overlapping_matches=False
166164
)

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def get_prepared_qat_module(
576576
quant_dtype: QuantDtype = QuantDtype.use_8a8w,
577577
submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
578578
) -> torch.fx.GraphModule:
579-
m = torch.export.export_for_training(module, inputs, strict=True).module()
579+
m = torch.export.export(module, inputs, strict=True).module()
580580

581581
quantizer = make_quantizer(
582582
quant_dtype=quant_dtype,

backends/test/harness/stages/quantize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
DuplicateDynamicQuantChainPass,
88
)
99

10-
from torch.export import export_for_training
10+
from torch.export import export
1111

1212
from torchao.quantization.pt2e.quantize_pt2e import (
1313
convert_pt2e,
@@ -47,7 +47,7 @@ def run(
4747
assert inputs is not None
4848
if self.is_qat:
4949
artifact.train()
50-
captured_graph = export_for_training(artifact, inputs, strict=True).module()
50+
captured_graph = export(artifact, inputs, strict=True).module()
5151

5252
assert isinstance(captured_graph, torch.fx.GraphModule)
5353

backends/test/suite/runner.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import re
66
import time
77
import unittest
8-
import warnings
98

109
from datetime import timedelta
1110
from typing import Any
@@ -283,10 +282,6 @@ def build_test_filter(args: argparse.Namespace) -> TestFilter:
283282
def runner_main():
284283
args = parse_args()
285284

286-
# Suppress deprecation warnings for export_for_training, as it generates a
287-
# lot of log spam. We don't really need the warning here.
288-
warnings.simplefilter("ignore", category=FutureWarning)
289-
290285
seed = args.seed or random.randint(0, 100_000_000)
291286
print(f"Running with seed {seed}.")
292287

backends/transforms/test/test_duplicate_dynamic_quant_chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _test_duplicate_chain(
5858

5959
# program capture
6060
m = copy.deepcopy(m_eager)
61-
m = torch.export.export_for_training(m, example_inputs, strict=True).module()
61+
m = torch.export.export(m, example_inputs, strict=True).module()
6262

6363
m = prepare_pt2e(m, quantizer)
6464
# Calibrate

backends/vulkan/test/test_vulkan_passes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def quantize_and_lower_module(
5858
_check_ir_validity=False,
5959
)
6060

61-
program = torch.export.export_for_training(
61+
program = torch.export.export(
6262
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
6363
).module()
6464

@@ -95,7 +95,6 @@ def op_node_count(graph_module: torch.fx.GraphModule, canonical_op_name: str) ->
9595

9696

9797
class TestVulkanPasses(unittest.TestCase):
98-
9998
def test_fuse_int8pack_mm(self):
10099
K = 256
101100
N = 256
@@ -184,7 +183,7 @@ def test_fuse_linear_qta8a_qga4w(self):
184183
_check_ir_validity=False,
185184
)
186185

187-
program = torch.export.export_for_training(
186+
program = torch.export.export(
188187
quantized_model, sample_inputs, strict=True
189188
).module()
190189

0 commit comments

Comments
 (0)