@@ -21,6 +21,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
+
 from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner
 
@@ -33,9 +34,9 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
-
 from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
+from omegaconf import DictConfig
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
@@ -87,7 +88,7 @@ def __init__(
         use_kv_cache,
         example_inputs,
         example_kwarg_inputs: Optional[Dict] = None,
-        args: Optional[Any] = None,
+        config: Optional[DictConfig] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
         calibration_tasks: Optional[List[str]] = None,
@@ -121,7 +122,7 @@ def __init__(
         self.output_dir = "."
         self.dynamic_shapes = dynamic_shapes
         self._saved_pte_filename = None
-        self.args = args
+        self.config = config
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
@@ -203,7 +204,7 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.args.backend.qnn.enabled:
+            if self.config.backend.qnn.enabled:
                 # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
                 # See issue: https://github.com/pytorch/executorch/issues/7373
 
@@ -249,8 +250,8 @@ def export(self) -> "LLMEdgeManager":
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
-        if self.args.export.export_only:
-            torch.export.save(exported_module, self.args.export.output_name)
+        if self.config.export.export_only:
+            torch.export.save(exported_module, self.config.export.output_name)
         return self
 
     def run_canonical_optimizations(self):
@@ -414,7 +415,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
            self.export()
 
         override_export_behaviour = contextlib.nullcontext()
-        if self.args.backend.qnn.enabled:
+        if self.config.backend.qnn.enabled:
             override_export_behaviour = patch.object(
                 torch._utils_internal,
                 "export_training_ir_rollout_check",
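
For reference, a minimal, hypothetical sketch (not part of this PR) of how the new `config: Optional[DictConfig]` argument could be populated with omegaconf. It covers only the keys that appear in this diff (`backend.qnn.enabled`, `export.export_only`, `export.output_name`); the concrete values, including the output file name, are placeholders.

    from omegaconf import OmegaConf

    # Build a DictConfig covering only the keys read in the diff above.
    # Values are illustrative placeholders, not defaults from the PR.
    config = OmegaConf.create(
        {
            "backend": {"qnn": {"enabled": False}},
            "export": {"export_only": False, "output_name": "model.pte"},
        }
    )

    # The manager would then read nested keys via attribute access,
    # e.g. config.backend.qnn.enabled, replacing the old self.args.* lookups.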