diff --git a/docs/source/tutorials_source/pt2e_quant_ptq.rst b/docs/source/tutorials_source/pt2e_quant_ptq.rst index 0b483697e3..86906f2c34 100644 --- a/docs/source/tutorials_source/pt2e_quant_ptq.rst +++ b/docs/source/tutorials_source/pt2e_quant_ptq.rst @@ -362,7 +362,7 @@ Here is how you can use ``torch.export`` to export the model: {0: torch.export.Dim("dim")} if i == 0 else None for i in range(len(example_inputs)) ) - exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module() + exported_model = torch.export.export(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module() # for pytorch 2.5 and before # dynamic_shape API may vary as well @@ -501,7 +501,7 @@ Now we can compare the size and model accuracy with baseline model. # Quantized model size and accuracy print("Size of model after quantization") # export again to remove unused weights - quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module() + quantized_model = torch.export.export(quantized_model, example_inputs).module() print_size_of_model(quantized_model) top1, top5 = evaluate(quantized_model, criterion, data_loader_test) diff --git a/docs/source/tutorials_source/pt2e_quant_qat.rst b/docs/source/tutorials_source/pt2e_quant_qat.rst index cba870c668..d8eb013d70 100644 --- a/docs/source/tutorials_source/pt2e_quant_qat.rst +++ b/docs/source/tutorials_source/pt2e_quant_qat.rst @@ -13,7 +13,6 @@ to the post training quantization (PTQ) flow for the most part: .. code:: python import torch - from torch._export import capture_pre_autograd_graph from torchao.quantization.pt2e.quantize_pt2e import ( prepare_qat_pt2e, convert_pt2e, @@ -434,7 +433,6 @@ prepared. For example: .. code:: python - from torch._export import capture_pre_autograd_graph from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, @@ -443,7 +441,7 @@ prepared. For example: example_inputs = (torch.rand(2, 3, 224, 224),) float_model = resnet18(pretrained=False) - exported_model = capture_pre_autograd_graph(float_model, example_inputs) + exported_model = torch.export.export(float_model, example_inputs).module() quantizer = XNNPACKQuantizer() quantizer.set_global(get_symmetric_quantization_config(is_qat=True)) prepared_model = prepare_qat_pt2e(exported_model, quantizer) diff --git a/docs/source/tutorials_source/pt2e_quant_x86_inductor.rst b/docs/source/tutorials_source/pt2e_quant_x86_inductor.rst index e4faec469f..5cbe96a67a 100644 --- a/docs/source/tutorials_source/pt2e_quant_x86_inductor.rst +++ b/docs/source/tutorials_source/pt2e_quant_x86_inductor.rst @@ -105,7 +105,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t exported_model = export( model, example_inputs - ) + ).module() Next, we will have the FX Module to be quantized. @@ -243,12 +243,10 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow: .. code:: python import torch - from torch._export import capture_pre_autograd_graph from torchao.quantization.pt2e.quantize_pt2e import ( prepare_qat_pt2e, convert_pt2e, ) - from torch.export import export import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer @@ -264,9 +262,7 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow: m = M() # Step 1. 
program capture - # NOTE: this API will be updated to torch.export API in the future, but the captured - # result shoud mostly stay the same - exported_model = export(m, example_inputs) + exported_model = torch.export.export(m, example_inputs).module() # we get a model with aten ops # Step 2. quantization-aware training diff --git a/examples/sam2_amg_server/compile_export_utils.py b/examples/sam2_amg_server/compile_export_utils.py index 3797e60af6..32667748a5 100644 --- a/examples/sam2_amg_server/compile_export_utils.py +++ b/examples/sam2_amg_server/compile_export_utils.py @@ -118,10 +118,7 @@ def aot_compile( "max_autotune": True, "triton.cudagraphs": True, } - - from torch.export import export_for_training - - exported = export_for_training(fn, sample_args, sample_kwargs, strict=True) + exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True) exported.run_decompositions() output_path = torch._inductor.aoti_compile_and_package( exported, diff --git a/examples/sam2_vos_example/compile_export_utils.py b/examples/sam2_vos_example/compile_export_utils.py index 73551db675..3bb5add5a4 100644 --- a/examples/sam2_vos_example/compile_export_utils.py +++ b/examples/sam2_vos_example/compile_export_utils.py @@ -81,10 +81,7 @@ def aot_compile( "max_autotune": True, "triton.cudagraphs": True, } - - from torch.export import export_for_training - - exported = export_for_training(fn, sample_args, sample_kwargs, strict=True) + exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True) exported.run_decompositions() output_path = torch._inductor.aoti_compile_and_package( exported, diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index aa9eccc903..a1d87dbc91 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -242,10 +242,7 @@ def forward(self, x): # program capture m = copy.deepcopy(m_eager) - m = torch.export.texport_for_training( - m, - example_inputs, - ).module() + m = torch.export.export(m, example_inputs).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 455a51061b..afa6cfff99 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1953,9 +1953,7 @@ def forward(self, x): # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu # we can re-enable this after non-functional IR is enabled in export # model = torch.export.export(model, example_inputs).module() - model = torch.export.export_for_training( - model, example_inputs, strict=True - ).module() + model = torch.export.export(model, example_inputs, strict=True).module() after_export = model(x) self.assertTrue(torch.equal(after_export, ref)) if api is _int8da_int4w_api: diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index ec4f928df2..ceb9e840c1 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -157,8 +157,6 @@ def _check_common( ) @config.patch({"freezing": True}) def _test_sdpa_int8_rewriter(self): - from torch.export import export_for_training - import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( @@ -199,11 +197,7 @@ def _test_sdpa_int8_rewriter(self): quantizer.set_function_type_qconfig( 
torch.matmul, quantizer.get_global_quantization_config() ) - export_model = export_for_training( - mod, - inputs, - strict=True, - ).module() + export_model = torch.export.export(mod, inputs, strict=True).module() prepare_model = prepare_pt2e(export_model, quantizer) prepare_model(*inputs) convert_model = convert_pt2e(prepare_model) diff --git a/test/quantization/pt2e/test_arm_inductor_quantizer.py b/test/quantization/pt2e/test_arm_inductor_quantizer.py index 4c3b397382..42a826e43a 100644 --- a/test/quantization/pt2e/test_arm_inductor_quantizer.py +++ b/test/quantization/pt2e/test_arm_inductor_quantizer.py @@ -14,7 +14,6 @@ import torch import torch.nn as nn -from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, ) @@ -315,10 +314,7 @@ def _test_quantizer( # program capture m = copy.deepcopy(m_eager) - m = export_for_training( - m, - example_inputs, - ).module() + m = torch.export.export(m, example_inputs).module() # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) @@ -576,7 +572,7 @@ def _test_linear_unary_helper( Test pattern of linear with unary post ops (e.g. relu) with ArmInductorQuantizer. """ use_bias_list = [True, False] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of export inplace_list = [False] if post_op_algo_list is None: post_op_algo_list = [None] @@ -716,7 +712,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op is supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of export inplace_add_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = ArmInductorQuantizer().set_global( @@ -1078,7 +1074,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs).module() + m = torch.export.export(m, example_inputs).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. 
diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index 8430f605e1..dcfdfd4553 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -110,7 +110,7 @@ def _test_duplicate_dq( # program capture m = copy.deepcopy(m_eager) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index c9fa3960ee..cb54eba66d 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -107,7 +107,7 @@ def _test_metadata_porting( # program capture m = copy.deepcopy(m_eager) - m = torch.export.export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 07d884e45f..a050f476ef 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -20,11 +20,8 @@ from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase from torchao.utils import TORCH_VERSION_AT_LEAST_2_8 -if TORCH_VERSION_AT_LEAST_2_8: - from torch.export import export_for_training - # Increase cache size limit to avoid FailOnRecompileLimitHit error when running multiple tests -# that use export_for_training, which causes many dynamo recompilations +# that use torch.export.export, which causes many dynamo recompilations if TORCH_VERSION_AT_LEAST_2_8: torch._dynamo.config.cache_size_limit = 128 @@ -37,7 +34,7 @@ class TestNumericDebuggerInfra(PT2ENumericDebuggerTestCase): def test_simple(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() self._assert_each_node_has_from_node_source(m) from_node_source_map = self._extract_from_node_source(m) @@ -50,7 +47,7 @@ def test_simple(self): def test_control_flow(self): m = TestHelperModules.ControlFlow() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() self._assert_each_node_has_from_node_source(m) @@ -93,13 +90,13 @@ def test_deepcopy_preserve_handle(self): def test_re_export_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() self._assert_each_node_has_from_node_source(m) from_node_source_map_ref = self._extract_from_node_source(m) - ep_reexport = export_for_training(m, example_inputs, strict=True) + ep_reexport = torch.export.export(m, example_inputs, strict=True) m_reexport = ep_reexport.module() self._assert_each_node_has_from_node_source(m_reexport) @@ -110,7 +107,7 @@ def test_re_export_preserve_handle(self): def test_run_decompositions_same_handle_id(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() 
self._assert_each_node_has_from_node_source(m) @@ -136,7 +133,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self): for m in test_models: example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() self._assert_each_node_has_from_node_source(m) @@ -161,7 +158,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self): def test_prepare_for_propagation_comparison(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() m_logger = prepare_for_propagation_comparison(m) ref = m(*example_inputs) @@ -177,7 +174,7 @@ def test_prepare_for_propagation_comparison(self): def test_added_node_gets_unique_id(self) -> None: m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) ref_from_node_source = self._extract_from_node_source(ep.module()) ref_counter = Counter(ref_from_node_source.values()) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 0c1a1f23c9..3f891550c5 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -790,7 +790,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5)) # program capture - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, BackendAQuantizer()) # make sure the two observers for input are shared conv_output_obs = [] @@ -850,7 +850,7 @@ def _test_transitive_sharing_with_cat_helper(self, quantizer): ) # program capture - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) # make sure the two input observers and output are shared @@ -1169,7 +1169,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: ) # program capture - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = BackendAQuantizer() m = prepare_pt2e(m, quantizer) m(*example_inputs) @@ -1321,7 +1321,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: m = M().eval() example_inputs = torch.randn(1, 2, 3, 3) - m = export_for_training(m, (example_inputs,), strict=True).module() + m = torch.export.export(m, (example_inputs,), strict=True).module() with self.assertRaises(Exception): m = prepare_pt2e(m, BackendAQuantizer()) @@ -1329,7 +1329,7 @@ def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): # resetting dynamo cache torch._dynamo.reset() - m = export_for_training( + m = torch.export.export( m, example_inputs, ).module() @@ -1478,7 +1478,7 @@ def forward(self, x): quantizer.set_global(operator_config) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() weight_meta = None for n in m.graph.nodes: if ( @@ -1566,7 +1566,7 @@ def forward(self, x): m = M().eval() quantizer = TestQuantizer() example_inputs = (torch.randn(1, 2, 3, 3),) 
- m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) node_occurrence = { @@ -1617,7 +1617,7 @@ def forward(self, x, y, z): torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3), ) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) node_occurrence = { @@ -1872,7 +1872,7 @@ def forward(self, x): example_inputs = (torch.randn(1),) m = M().train() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() if inplace: target = torch.ops.aten.dropout_.default else: @@ -1934,7 +1934,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() # Assert that batch norm op exists and is in train mode bn_node = self._get_node(m, bn_train_op) @@ -1965,7 +1965,7 @@ def test_disallow_eval_train(self): m.train() # After export: this is not OK - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() with self.assertRaises(NotImplementedError): m.eval() with self.assertRaises(NotImplementedError): @@ -2008,7 +2008,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool): targets = [n.target for n in m.graph.nodes] @@ -2074,7 +2074,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() torchao.quantization.pt2e.allow_exported_model_train_eval(m) # Mock m.recompile() to count how many times it's been called @@ -2106,7 +2106,7 @@ def _fake_recompile(): def test_model_is_exported(self): m = TestHelperModules.ConvWithBNRelu(relu=True) example_inputs = (torch.rand(3, 3, 5, 5),) - exported_gm = export_for_training(m, example_inputs, strict=True).module() + exported_gm = torch.export.export(m, example_inputs, strict=True).module() fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs) self.assertTrue( torchao.quantization.pt2e.export_utils.model_is_exported(exported_gm) @@ -2124,7 +2124,7 @@ def test_reentrant(self): quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_per_channel=True, is_qat=True) ) - m.conv_bn_relu = export_for_training( + m.conv_bn_relu = torch.export.export( m.conv_bn_relu, example_inputs, strict=True ).module() m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) @@ -2134,7 +2134,7 @@ def test_reentrant(self): quantizer = XNNPACKQuantizer().set_module_type( torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False) ) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m = convert_pt2e(m) @@ -2297,7 +2297,7 @@ def test_speed(self): def dynamic_quantize_pt2e(model, 
example_inputs): torch._dynamo.reset() - model = export_for_training(model, example_inputs, strict=True).module() + model = torch.export.export(model, example_inputs, strict=True).module() # Per channel quantization for weight # Dynamic quantization for activation # Please read a detail: https://fburl.com/code/30zds51q @@ -2704,7 +2704,7 @@ def forward(self, x): example_inputs = (torch.randn(1, 3, 5, 5),) m = M() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(), ) @@ -2786,7 +2786,7 @@ def prepare_obs_or_fq_callback( edge_or_node_to_obs_or_fq[x] = new_observer example_inputs = (torch.rand(1, 32, 16, 16),) - gm = export_for_training(Model().eval(), example_inputs, strict=True).module() + gm = torch.export.export(Model().eval(), example_inputs, strict=True).module() gm = prepare_pt2e(gm, BackendAQuantizer()) gm = convert_pt2e(gm) for n in gm.graph.nodes: @@ -2813,7 +2813,7 @@ def check_nn_module(node): "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1] ) - m.conv_bn_relu = export_for_training( + m.conv_bn_relu = torch.export.export( m.conv_bn_relu, example_inputs, strict=True ).module() for node in m.conv_bn_relu.graph.nodes: @@ -2898,7 +2898,7 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool: quantizer = TestQuantizer() example_inputs = (torch.randn(1, 2, 3, 3),) quantizer.set_example_inputs(example_inputs) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() # Check that the model has in-place ops self.assertTrue(has_inplace_ops(m)) m = prepare_pt2e(m, quantizer) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index e0a51453a9..5f82398811 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -149,7 +149,7 @@ def _verify_symmetric_xnnpack_qat_numerics_helper( is_per_channel=is_per_channel, is_qat=True ) ) - model_pt2e = export_for_training( + model_pt2e = torch.export.export( model_pt2e, example_inputs, strict=True ).module() model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer) @@ -248,7 +248,7 @@ def _verify_symmetric_xnnpack_qat_graph_helper( quantizer.set_global( get_symmetric_quantization_config(is_per_channel, is_qat=True) ) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -638,7 +638,7 @@ def forward(self, x): m = M(self.conv_class, self.bn_class, backbone) quantizer = XNNPACKQuantizer() quantizer.set_global(get_symmetric_quantization_config(is_qat=True)) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) m = convert_pt2e(m) @@ -696,7 +696,7 @@ def get_source_fn(node: torch.fx.Node): def test_qat_conv_bn_bias_derived_qspec(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = ConvBnDerivedBiasQuantizer() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -743,7 +743,7 @@ def test_qat_conv_bn_bias_derived_qspec(self): def 
test_qat_per_channel_weight_custom_dtype(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = ConvBnInt32WeightQuantizer() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -797,7 +797,7 @@ def test_qat_conv_transpose_bn_relu(self): def test_qat_conv_bn_per_channel_weight_bias(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True) m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -854,7 +854,7 @@ def test_fold_bn_erases_bn_node(self): it into conv in `convert_pt2e` even in train mode. """ m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False) - m = export_for_training(m, self.example_inputs, strict=True).module() + m = torch.export.export(m, self.example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_global( get_symmetric_quantization_config(is_per_channel=False, is_qat=True), @@ -1106,7 +1106,7 @@ def _prepare_qat_linears(self, model): in_channels = child.linear1.weight.size(1) example_input = (torch.rand((1, in_channels)),) - traced_child = export_for_training( + traced_child = torch.export.export( child, example_input, strict=True ).module() quantizer = XNNPACKQuantizer() @@ -1139,7 +1139,7 @@ def test_mixing_qat_ptq(self): self._convert_qat_linears(model) model(*example_inputs) - model_pt2e = export_for_training(model, example_inputs, strict=True).module() + model_pt2e = torch.export.export(model, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_module_type(torch.nn.Linear, None) diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py index abe79a08e3..6b17162495 100644 --- a/test/quantization/pt2e/test_representation.py +++ b/test/quantization/pt2e/test_representation.py @@ -46,7 +46,7 @@ def _test_representation( ) -> torch.nn.Module: # resetting dynamo cache torch._dynamo.reset() - model = export_for_training(model, example_inputs, strict=True).module() + model = torch.export.export(model, example_inputs, strict=True).module() model_copy = copy.deepcopy(model) model = prepare_pt2e(model, quantizer) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 42439552c6..099b77e0db 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -16,7 +16,6 @@ from torch._inductor import config from torch._inductor.test_case import TestCase, run_tests from torch._inductor.utils import run_and_get_code -from torch.export import export_for_training from torch.testing._internal.common_quantization import ( skipIfNoDynamoSupport, skipIfNoONEDNN, @@ -107,7 +106,7 @@ def _generate_qdq_quantized_model( ): maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() with maybe_no_grad: - export_model = export_for_training(mod, inputs, strict=True).module() + export_model = torch.export.export(mod, inputs, strict=True).module() quantizer = ( quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic) ) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 
9dc7da3571..3b09d5c8e8 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -676,7 +676,7 @@ def _test_quantizer( # program capture m = copy.deepcopy(m_eager) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) @@ -1430,7 +1430,7 @@ def _test_linear_unary_helper( Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer. """ use_bias_list = [True, False] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of torch.export.export inplace_list = [False] if post_op_algo_list is None: post_op_algo_list = [None] @@ -1570,7 +1570,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op is supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of torch.export.export inplace_add_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( @@ -1674,7 +1674,7 @@ def test_linear_binary2(self): Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1 """ example_inputs = (torch.randn(2, 16),) - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of torch.export.export inplace_add_list = [False] is_qat_list = [False, True] is_dynamic_list = [False, True] @@ -1743,9 +1743,9 @@ def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op and relu as unary post op are supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of torch.export.export inplace_add_list = [False] - # TODO test for inplace relu after refactoring of export_for_training + # TODO test for inplace relu after refactoring of torch.export.export inplace_relu_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( @@ -2353,7 +2353,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. 
diff --git a/torchao/quantization/pt2e/_numeric_debugger.py b/torchao/quantization/pt2e/_numeric_debugger.py index 5211e0f340..df01d02f99 100644 --- a/torchao/quantization/pt2e/_numeric_debugger.py +++ b/torchao/quantization/pt2e/_numeric_debugger.py @@ -51,7 +51,7 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None: Here's an example of using debug handle quantize flow:: - ep = export_for_training(eager_model, example_inputs) + ep = torch.export.export(eager_model, example_inputs) generate_numeric_debug_handle(ep) m = ep.module() diff --git a/torchao/quantization/pt2e/lowering.py b/torchao/quantization/pt2e/lowering.py index 76dad800cd..c0b4a3538b 100644 --- a/torchao/quantization/pt2e/lowering.py +++ b/torchao/quantization/pt2e/lowering.py @@ -55,7 +55,7 @@ def _node_replace(m): # type: ignore[no-untyped-def] m.recompile() lowered_model = ( - torch.export.export_for_training(model, example_inputs, strict=True) + torch.export.export(model, example_inputs, strict=True) .run_decompositions(_post_autograd_decomp_table()) .module() ) diff --git a/torchao/quantization/pt2e/quantize_pt2e.py b/torchao/quantization/pt2e/quantize_pt2e.py index e58dc8e3ee..1975642dfd 100644 --- a/torchao/quantization/pt2e/quantize_pt2e.py +++ b/torchao/quantization/pt2e/quantize_pt2e.py @@ -46,7 +46,7 @@ def prepare_pt2e( """Prepare a model for post training quantization Args: - * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API. + * `model` (torch.fx.GraphModule): a model captured by `torch.export.export` API. * `quantizer`: A backend specific quantizer that conveys how user want the model to be quantized. Tutorial for how to write a quantizer can be found here: https://pytorch.org/tutorials/prototype/pt2e_quantizer.html @@ -84,7 +84,7 @@ def calibrate(model, data_loader): # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured # result shoud mostly stay the same - m = torch.export.export_for_training(m, *example_inputs).module() + m = torch.export.export(m, *example_inputs).module() # we get a model with aten ops # Step 2. quantization @@ -169,7 +169,7 @@ def train_loop(model, train_data): # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured # result shoud mostly stay the same - m = torch.export.export_for_training(m, *example_inputs).module() + m = torch.export.export(m, *example_inputs).module() # we get a model with aten ops # Step 2. 
quantization diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py index 41a26b62eb..486f82c6a7 100644 --- a/torchao/quantization/pt2e/utils.py +++ b/torchao/quantization/pt2e/utils.py @@ -815,7 +815,7 @@ def _get_aten_graph_module_for_pattern( [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs] ) - aten_pattern = torch.export.export_for_training( + aten_pattern = torch.export.export( pattern, # type: ignore[arg-type] example_inputs, kwargs, diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py index a41d3f597f..74a0269018 100644 --- a/torchao/testing/pt2e/utils.py +++ b/torchao/testing/pt2e/utils.py @@ -72,7 +72,7 @@ def _test_quantizer( {0: torch.export.Dim("dim")} if i == 0 else None for i in range(len(example_inputs)) ) - m = export_for_training( + m = torch.export.export( m, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, @@ -113,7 +113,7 @@ def _test_quantizer( m_fx = _convert_to_reference_decomposed_fx( m_fx, backend_config=backend_config ) - m_fx = export_for_training( + m_fx = torch.export.export( m_fx, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, diff --git a/torchao/utils.py b/torchao/utils.py index f72e60e3d1..a32166d556 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -203,7 +203,7 @@ def _the_op_that_needs_to_be_preserved(...) # after this, `_the_op_that_needs_to_be_preserved` will be preserved as # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after - # torch.export.export / torch._export.export_for_training + # torch.export.export """ from torch._inductor.decomposition import register_decomposition
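Every capture site in this patch changes the same way: ``torch.export.export_for_training(...)`` (or the older ``capture_pre_autograd_graph``) becomes ``torch.export.export(...)`` followed by ``.module()``, and everything downstream (``prepare_pt2e``/``prepare_qat_pt2e``, calibration or training, ``convert_pt2e``) is unchanged. For reference, here is a minimal sketch of the migrated PTQ flow, assuming the ``torchao`` ``prepare_pt2e``/``convert_pt2e`` APIs and the XNNPACK quantizer imports shown in the tutorial hunks above; the toy module, shapes, and quantization config are illustrative only and are not part of the patch:

.. code:: python

    import torch
    from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(2, 16),)
    m = M().eval()

    # Step 1. program capture: torch.export.export returns an ExportedProgram;
    # .module() yields the aten-level GraphModule that prepare_pt2e expects.
    m = torch.export.export(m, example_inputs).module()

    # Step 2. quantization: annotate, calibrate, then convert.
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibration pass with representative inputs
    m = convert_pt2e(m)

The QAT variant shown in ``pt2e_quant_qat.rst`` differs only in swapping ``prepare_pt2e`` for ``prepare_qat_pt2e`` (with ``get_symmetric_quantization_config(is_qat=True)``) and running a training loop instead of calibration between capture and ``convert_pt2e``.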