10 | 10 |
11 | 11 | # pyre-unsafe |
12 | 12 |
| 13 | +import contextlib |
13 | 14 | import logging |
14 | 15 | from enum import Enum |
15 | 16 | from typing import Any, Callable, Dict, List, Optional |
| 17 | +from unittest.mock import patch |
16 | 18 |
17 | 19 | import torch |
18 | 20 | from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( |
@@ -184,15 +186,23 @@ def export(self) -> "LLMEdgeManager": |
184 | 186 | # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) |
185 | 187 | with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): |
186 | 188 | if hasattr(self.args, "qnn") and self.args.qnn: |
187 | | - # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a |
188 | | - # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details |
189 | | - exported_module = torch.export.export( |
190 | | - self.model, |
191 | | - self.example_inputs, |
192 | | - self.example_kwarg_inputs, |
193 | | - dynamic_shapes=dynamic_shape, |
194 | | - strict=True, |
195 | | - ) |
 | 189 | + # TODO: this is temporary, as the qnn flow does not work with the new, non-functional export IR. |
| 190 | + # See issue: https://github.com/pytorch/executorch/issues/7373 |
| 191 | + |
| 192 | + with patch.object( |
| 193 | + torch._utils_internal, |
| 194 | + "export_training_ir_rollout_check", |
| 195 | + return_value=False, |
| 196 | + ): |
 | 197 | + # TODO: this is also temporary; export_for_training doesn't work with qnn either, since we |
 | 198 | + # need a functional graph. See https://github.com/pytorch/executorch/pull/4627 for more details. |
| 199 | + exported_module = torch.export.export( |
| 200 | + self.model, |
| 201 | + self.example_inputs, |
| 202 | + self.example_kwarg_inputs, |
| 203 | + dynamic_shapes=dynamic_shape, |
| 204 | + strict=True, |
| 205 | + ) |
196 | 206 | else: |
197 | 207 | logging.info("Exporting with:") |
198 | 208 | logging.info(f"inputs: {self.example_inputs}") |
@@ -354,15 +364,25 @@ def export_to_edge(self) -> "LLMEdgeManager": |
354 | 364 | if self.pre_autograd_graph_module is None: |
355 | 365 | # Run export() if it didn't run |
356 | 366 | self.export() |
357 | | - self.edge_manager = export_to_edge( |
358 | | - self.pre_autograd_graph_module, # pyre-fixme[6] |
359 | | - self.example_inputs, |
360 | | - example_kwarg_inputs=self.example_kwarg_inputs, |
361 | | - dynamic_shapes=dynamic_shape, |
362 | | - edge_constant_methods=self.metadata, |
363 | | - edge_compile_config=edge_config, |
364 | | - verbose=self.verbose, |
365 | | - ) |
| 367 | + |
| 368 | + override_export_behaviour = contextlib.nullcontext() |
| 369 | + if hasattr(self.args, "qnn") and self.args.qnn: |
| 370 | + override_export_behaviour = patch.object( |
| 371 | + torch._utils_internal, |
| 372 | + "export_training_ir_rollout_check", |
| 373 | + return_value=False, |
| 374 | + ) |
| 375 | + |
| 376 | + with override_export_behaviour: |
| 377 | + self.edge_manager = export_to_edge( |
| 378 | + self.pre_autograd_graph_module, # pyre-fixme[6] |
| 379 | + self.example_inputs, |
| 380 | + example_kwarg_inputs=self.example_kwarg_inputs, |
| 381 | + dynamic_shapes=dynamic_shape, |
| 382 | + edge_constant_methods=self.metadata, |
| 383 | + edge_compile_config=edge_config, |
| 384 | + verbose=self.verbose, |
| 385 | + ) |
366 | 386 | return self |
367 | 387 |
368 | 388 | def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager": |
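The second hunk uses a small idiom worth calling out: defaulting `override_export_behaviour` to `contextlib.nullcontext()` and swapping in `patch.object(...)` only on the qnn path keeps the `with` block, and the `export_to_edge` call inside it, unconditional. A hedged standalone sketch of the same shape; `use_qnn` and `run_export` are hypothetical stand-ins for the manager's state and the real export call:

```python
# Sketch of the conditional-context idiom from the hunk above. Only the
# context-manager shape is taken from the diff; the names are illustrative.
import contextlib
from unittest.mock import patch

import torch
import torch._utils_internal


def export_with_optional_patch(use_qnn: bool, run_export):
    # Default to a no-op context so the call site below never forks.
    override = contextlib.nullcontext()
    if use_qnn:
        # Only the qnn flow needs the functional IR, so only it gets the patch.
        override = patch.object(
            torch._utils_internal,
            "export_training_ir_rollout_check",
            return_value=False,
        )
    with override:
        return run_export()
```

This keeps a single code path regardless of the backend, which is easier to maintain than duplicating the export call inside an `if`/`else`.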