@@ -13,7 +13,7 @@
 import contextlib
 import logging
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
 import torch
@@ -81,14 +81,13 @@ class LLMEdgeManager:
 
     def __init__(
         self,
-        model,
-        modelname,
-        max_seq_len,
-        dtype,
-        use_kv_cache,
-        example_inputs,
+        model: torch.nn.Module,
+        modelname: str,
+        max_seq_len: int,
+        use_kv_cache: bool,
+        example_inputs: Tuple[torch.Tensor, ...],
+        dtype: Optional[DType] = None,
         example_kwarg_inputs: Optional[Dict] = None,
-        args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
         calibration_tasks: Optional[List[str]] = None,
@@ -99,36 +98,42 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
+        save_exported_program: bool = False,
     ):
+        # Store necessary constructor arguments.
         self.model = model
-        # Note: treat this as the source of truth for the result of
-        # torch.export'ing a model. If the overall ExportedProgram is needed,
-        # make sure to re-export this graph module to persist any changes. See
-        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
-        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
-        self.dtype = dtype
+        self.use_kv_cache = use_kv_cache
         self.example_inputs = example_inputs
+        self.dtype = dtype
         self.example_kwarg_inputs = example_kwarg_inputs
-        self.use_kv_cache = use_kv_cache
-        self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
-        self.verbose = verbose
-        self.metadata = metadata
-        self.applied_source_transforms = []
-        self.edge_manager: Optional[EdgeProgramManager] = None
-        self.export_program = None
-        self.output_dir = "."
-        self.dynamic_shapes = dynamic_shapes
-        self._saved_pte_filename = None
-        self.args = args
+        self.generate_full_logits = generate_full_logits
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
-        self.canonical_passes = [RemoveRedundantTransposes()]
+        self.verbose = verbose
+        self.metadata = metadata
+        self.dynamic_shapes = dynamic_shapes
+        self.use_legacy_export = use_legacy_export
+        self.save_exported_program = save_exported_program
+
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.edge_manager: Optional[EdgeProgramManager] = None
+        self.canonical_passes = [
+            RemoveRedundantTransposes()
+        ]  # Graph transformation optimizations.
+        self.export_program = None  # Final result of lowering to executorch.
+        self.output_dir = "."
+        self._saved_pte_filename = None
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
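
With args removed, the constructor now takes typed keyword arguments directly. A minimal construction sketch, assuming a model module and an example_inputs tuple already exist; the names and values below are illustrative placeholders, not taken from this PR:

    manager = LLMEdgeManager(
        model=model,                    # assumed torch.nn.Module
        modelname="llama",              # also the stem of the saved .pt2 filename
        max_seq_len=128,
        use_kv_cache=True,
        example_inputs=example_inputs,  # assumed Tuple[torch.Tensor, ...]
        save_exported_program=True,     # takes over the old args.export_only save path
        use_legacy_export=False,        # set True for flows such as QNN
    )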
@@ -167,10 +172,9 @@ def source_transform(
         """
         for transform in transforms:
             self.model = transform(self.model)
-        self.applied_source_transforms.extend(transforms)
 
         if self.verbose:
-            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
+            logging.info(f"Applied source transforms: {transforms}")
             logging.info(f"Model after source transforms: {self.model}")
         return self
 
@@ -209,8 +213,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if hasattr(self.args, "qnn") and self.args.qnn:
-                # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
+            if self.use_legacy_export:
+                # TODO: this is for use cases such as qnn, which does not work with the new, non-functional export IR.
                 # See issue: https://github.com/pytorch/executorch/issues/7373
 
                 with patch.object(
@@ -256,8 +260,12 @@ def export(self) -> "LLMEdgeManager":
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        if self.save_exported_program:
+            export_output = f"{self.modelname}.pt2"
+            logging.info(
+                f"Saving torch.export()/export_for_training() result to {export_output}"
+            )
+            torch.export.save(exported_module, export_output)
         return self
 
     def run_canonical_optimizations(self):
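
With save_exported_program set, the torch.export() result is written to "<modelname>.pt2" rather than to an args-provided output name. A quick sketch of reloading that file for inspection, assuming modelname was "llama" (the filename is a placeholder):

    import torch

    ep = torch.export.load("llama.pt2")  # reload the saved ExportedProgram
    print(ep.graph_module.graph)         # inspect the captured graph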
@@ -421,7 +429,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 self.export()
 
             override_export_behaviour = contextlib.nullcontext()
-            if hasattr(self.args, "qnn") and self.args.qnn:
+            if self.use_legacy_export:
                 override_export_behaviour = patch.object(
                     torch._utils_internal,
                     "export_training_ir_rollout_check",
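
Both export paths now key off self.use_legacy_export; when the flag is false, the patch collapses to a no-op nullcontext. A standalone sketch of that conditional context-manager pattern, using a toy class rather than torch._utils_internal:

    import contextlib
    from unittest.mock import patch

    class Cfg:
        legacy = True

    use_legacy_export = True  # stand-in for the manager attribute

    ctx = contextlib.nullcontext()
    if use_legacy_export:
        ctx = patch.object(Cfg, "legacy", new=False)

    with ctx:
        print(Cfg.legacy)  # False only while the patch is active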