Commit 5fde193

revert change
1 parent 19c5aa1 commit 5fde193

File tree: 3 files changed (+57, -23 lines)


examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 0 deletions

@@ -1189,6 +1189,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         calibration_seq_length=llm_config.quantization.calibration_seq_length,
         calibration_data=llm_config.quantization.calibration_data,
         tokenizer_path=llm_config.base.tokenizer_path,
+        use_legacy_export=llm_config.backend.qnn.enabled,
         save_exported_program=llm_config.export.export_only,
         verbose=llm_config.debug.verbose,
         metadata=_load_llama_model_metadata(
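
For illustration, a minimal runnable sketch of what this one-line change wires up: the builder's legacy-export flag is simply the "QNN backend enabled" bit of the config. The dataclasses below are hypothetical stand-ins for LlmConfig; only the attribute path llm_config.backend.qnn.enabled and the use_legacy_export keyword come from the diff.

# Hypothetical stand-ins for the relevant slice of LlmConfig.
from dataclasses import dataclass, field

@dataclass
class QnnConfig:
    enabled: bool = False

@dataclass
class BackendConfig:
    qnn: QnnConfig = field(default_factory=QnnConfig)

@dataclass
class LlmConfigSketch:
    backend: BackendConfig = field(default_factory=BackendConfig)

llm_config = LlmConfigSketch(backend=BackendConfig(qnn=QnnConfig(enabled=True)))

# Mirrors the added keyword argument: enabling the QNN backend opts the
# LLMEdgeManager builder into the legacy export path.
use_legacy_export = llm_config.backend.qnn.enabled
print(use_legacy_export)  # True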

extension/llm/export/builder.py

Lines changed: 54 additions & 22 deletions

@@ -10,9 +10,11 @@

 # pyre-unsafe

+import contextlib
 import logging
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple
+from unittest.mock import patch

 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -94,6 +96,7 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
         save_exported_program: bool = False,
     ):
         # Store necessary constructor arguments.

@@ -114,6 +117,7 @@ def __init__(
         self.verbose = verbose
         self.metadata = metadata
         self.dynamic_shapes = dynamic_shapes
+        self.use_legacy_export = use_legacy_export
         self.save_exported_program = save_exported_program

         # Note: treat this as the source of truth for the result of
@@ -225,20 +229,39 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if module:
-                logging.info("Re-exporting with:")
+            if self.use_legacy_export:
+                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
+                # See issue: https://github.com/pytorch/executorch/issues/7373
+
+                with patch.object(
+                    torch._utils_internal,
+                    "export_training_ir_rollout_check",
+                    return_value=False,
+                ):
+                    # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
+                    # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
+                    exported_module = torch.export.export(
+                        self.model if not module else module,
+                        self.example_inputs,
+                        self.example_kwarg_inputs,
+                        dynamic_shapes=dynamic_shape,
+                        strict=True,
+                    )
             else:
-                logging.info("Exporting with:")
-            logging.info(f"inputs: {self.example_inputs}")
-            logging.info(f"kwargs: {self.example_kwarg_inputs}")
-            logging.info(f"dynamic shapes: {dynamic_shape}")
-            exported_module = export_for_training(
-                self.model if not module else module,
-                self.example_inputs,
-                kwargs=self.example_kwarg_inputs,
-                dynamic_shapes=dynamic_shape,
-                strict=True,
-            )
+                if module:
+                    logging.info("Re-exporting with:")
+                else:
+                    logging.info("Exporting with:")
+                logging.info(f"inputs: {self.example_inputs}")
+                logging.info(f"kwargs: {self.example_kwarg_inputs}")
+                logging.info(f"dynamic shapes: {dynamic_shape}")
+                exported_module = export_for_training(
+                    self.model if not module else module,
+                    self.example_inputs,
+                    kwargs=self.example_kwarg_inputs,
+                    dynamic_shapes=dynamic_shape,
+                    strict=True,
+                )
         return exported_module

     def export(self) -> "LLMEdgeManager":
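
To make the branch above concrete, here is a self-contained sketch of the two export paths, runnable outside the builder. The toy module and the flag value are illustrative; the patch target torch._utils_internal.export_training_ir_rollout_check and the strict torch.export.export call come from the diff, and the patch is only meaningful on PyTorch builds where that rollout check still exists.

# Minimal sketch: with the legacy path, the training-IR rollout check is patched
# off so torch.export.export produces the pre-training-IR (functional) graph that
# QNN expects; otherwise export_for_training is used as in the non-legacy branch.
from unittest.mock import patch

import torch

class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) + 1

model = ToyModel()
example_inputs = (torch.randn(2, 4),)
use_legacy_export = True  # e.g. when the QNN backend is enabled

if use_legacy_export:
    with patch.object(
        torch._utils_internal,
        "export_training_ir_rollout_check",
        return_value=False,
    ):
        exported = torch.export.export(model, example_inputs, strict=True)
else:
    exported = torch.export.export_for_training(model, example_inputs, strict=True)

print(exported.graph_module.graph)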
@@ -423,15 +446,24 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # Run export() if it didn't run
         self.export()

-        self.edge_manager = export_to_edge(
-            self.pre_autograd_graph_module,  # pyre-fixme[6]
-            self.example_inputs,
-            example_kwarg_inputs=self.example_kwarg_inputs,
-            dynamic_shapes=dynamic_shape,
-            edge_constant_methods=self.metadata,
-            edge_compile_config=edge_config,
-            verbose=self.verbose,
-        )
+        override_export_behaviour = contextlib.nullcontext()
+        if self.use_legacy_export:
+            override_export_behaviour = patch.object(
+                torch._utils_internal,
+                "export_training_ir_rollout_check",
+                return_value=False,
+            )
+
+        with override_export_behaviour:
+            self.edge_manager = export_to_edge(
+                self.pre_autograd_graph_module,  # pyre-fixme[6]
+                self.example_inputs,
+                example_kwarg_inputs=self.example_kwarg_inputs,
+                dynamic_shapes=dynamic_shape,
+                edge_constant_methods=self.metadata,
+                edge_compile_config=edge_config,
+                verbose=self.verbose,
+            )
         return self

     def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
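
The export_to_edge change uses a small, reusable pattern: build either a no-op context or a patch up front, then run the same code under whichever was chosen. A stand-alone sketch of just that pattern follows; the patched method here is a local stand-in, not the real torch internal.

# Stand-alone sketch of the "optional override" pattern: contextlib.nullcontext()
# when no behaviour change is needed, unittest.mock.patch.object() when it is.
import contextlib
from unittest.mock import patch

class Rollout:
    def check(self) -> bool:
        return True  # stand-in for the real rollout check

rollout = Rollout()

def run_export(use_legacy: bool) -> bool:
    override = contextlib.nullcontext()
    if use_legacy:
        # The patch is only active inside the with-block below.
        override = patch.object(Rollout, "check", return_value=False)
    with override:
        return rollout.check()

print(run_export(False))  # True  - nothing patched, normal behaviour
print(run_export(True))   # False - check() temporarily patched off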

extension/llm/export/partitioner_lib.py

Lines changed: 2 additions & 1 deletion

@@ -216,5 +216,6 @@ def get_qnn_partitioner(
         ),
         skip_node_id_set={},
         skip_node_op_set=skip_node_op_set,
-        skip_mutable_buffer=False,
+        # TODO: if deprecated legacy export, skip_mutable_buffer can be set False
+        skip_mutable_buffer=True,
     )
