From e2cc5b350685be54d93a1200e352cbcc71cf346a Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Mon, 7 Jul 2025 10:55:58 -0700
Subject: [PATCH] Remove the legacy export (#12218)

Summary: As title: remove the legacy export path. The issue that previously
required it should be fixed by https://github.com/pytorch/executorch/pull/11782

Differential Revision: D77761473
---
 examples/models/llama/export_llama_lib.py |  1 -
 extension/llm/export/builder.py           | 55 ++++++-----------------
 2 files changed, 13 insertions(+), 43 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 43ae595f797..39f5f2ec0cd 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1216,7 +1216,6 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         calibration_seq_length=llm_config.quantization.calibration_seq_length,
         calibration_data=llm_config.quantization.calibration_data,
         tokenizer_path=llm_config.base.tokenizer_path,
-        use_legacy_export=llm_config.backend.qnn.enabled,
         save_exported_program=llm_config.export.export_only,
         verbose=llm_config.debug.verbose,
         metadata=_load_llama_model_metadata(
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 333a18cdf84..6db881c5274 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -14,7 +14,6 @@
 import logging
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple
-from unittest.mock import patch
 
 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -96,7 +95,6 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
-        use_legacy_export: bool = False,
         save_exported_program: bool = False,
     ):
         # Store necessary constructor arguments.
@@ -117,7 +115,6 @@ def __init__(
         self.verbose = verbose
         self.metadata = metadata
         self.dynamic_shapes = dynamic_shapes
-        self.use_legacy_export = use_legacy_export
         self.save_exported_program = save_exported_program
 
         # Note: treat this as the source of truth for the result of
@@ -228,39 +225,20 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.use_legacy_export:
-                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
-                # See issue: https://github.com/pytorch/executorch/issues/7373
-
-                with patch.object(
-                    torch._utils_internal,
-                    "export_training_ir_rollout_check",
-                    return_value=False,
-                ):
-                    # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
-                    # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
-                    exported_module = torch.export.export(
-                        self.model if not module else module,
-                        self.example_inputs,
-                        self.example_kwarg_inputs,
-                        dynamic_shapes=dynamic_shape,
-                        strict=True,
-                    )
+            if module:
+                logging.info("Re-exporting with:")
             else:
-                if module:
-                    logging.info("Re-exporting with:")
-                else:
-                    logging.info("Exporting with:")
-                logging.info(f"inputs: {self.example_inputs}")
-                logging.info(f"kwargs: {self.example_kwarg_inputs}")
-                logging.info(f"dynamic shapes: {dynamic_shape}")
-                exported_module = export_for_training(
-                    self.model if not module else module,
-                    self.example_inputs,
-                    kwargs=self.example_kwarg_inputs,
-                    dynamic_shapes=dynamic_shape,
-                    strict=True,
-                )
+                logging.info("Exporting with:")
+            logging.info(f"inputs: {self.example_inputs}")
+            logging.info(f"kwargs: {self.example_kwarg_inputs}")
+            logging.info(f"dynamic shapes: {dynamic_shape}")
+            exported_module = export_for_training(
+                self.model if not module else module,
+                self.example_inputs,
+                kwargs=self.example_kwarg_inputs,
+                dynamic_shapes=dynamic_shape,
+                strict=True,
+            )
         return exported_module
 
     def export(self) -> "LLMEdgeManager":
@@ -446,13 +424,6 @@ def export_to_edge(self) -> "LLMEdgeManager":
             self.export()
 
         override_export_behaviour = contextlib.nullcontext()
-        if self.use_legacy_export:
-            override_export_behaviour = patch.object(
-                torch._utils_internal,
-                "export_training_ir_rollout_check",
-                return_value=False,
-            )
-
         with override_export_behaviour:
             self.edge_manager = export_to_edge(
                 self.pre_autograd_graph_module,  # pyre-fixme[6]
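For context: with the legacy branch removed, every caller of LLMEdgeManager._export takes
the single torch.export.export_for_training path shown in the hunk above. Below is a
minimal, self-contained sketch of that path, not part of the patch: TinyModule and its
example inputs are hypothetical stand-ins rather than ExecuTorch code, and the context
managers simply mirror the ones used in _export.

import torch
from torch.export import export_for_training
from torch.nn.attention import SDPBackend, sdpa_kernel


class TinyModule(torch.nn.Module):
    """Hypothetical stand-in for the model that LLMEdgeManager would export."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(self.linear(x))


example_inputs = (torch.randn(2, 8),)

# Mirror the context managers used in LLMEdgeManager._export: the MATH SDPA backend keeps
# tracing simple, and no_grad keeps training-only ops (e.g. dropout) out of the graph.
with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    exported_module = export_for_training(
        TinyModule(),
        example_inputs,
        kwargs=None,
        dynamic_shapes=None,
        strict=True,
    )

# An ExportedProgram carrying the functional training-IR graph.
print(exported_module)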