49 changes: 39 additions & 10 deletions extension/llm/export/builder.py
@@ -89,7 +89,9 @@ def __init__(
dynamic_shapes: Optional[Any] = None,
):
self.model = model
self.pre_autograd_exported_program: Optional[ExportedProgram] = None
self.exported_program: Optional[ExportedProgram] = None
# Self.exported_program's pre-autograd graph module, for running
# transform passes on the graph prior to torch.export().
self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
self.modelname = modelname
self.max_seq_len = max_seq_len
@@ -184,7 +186,21 @@ def _get_edge_config(self) -> EdgeCompileConfig:
)
return edge_config

def export(self) -> "LLMEdgeManager":
def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
"""
Exports the model pre-autograd. This is not a full export, since it uses
torch.export_for_training() to keep autograd-safe ops from getting decomposed.
The full torch.export() is called later, during to_edge() or
to_edge_transform_and_lower().

The optional `module` argument is included so that the user can re-export
an already-exported ExportedProgram's graph module, persisting any changes
into a new ExportedProgram.

Args:
module (Optional[torch.nn.Module]): module to export.

"""
dynamic_shape = self._get_dynamic_shape()
# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -201,25 +217,30 @@ def export(self) -> "LLMEdgeManager":
# TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
# functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
exported_module = torch.export.export(
self.model,
self.model if not module else module,
self.example_inputs,
self.example_kwarg_inputs,
dynamic_shapes=dynamic_shape,
strict=True,
)
else:
logging.info("Exporting with:")
if module:
logging.info("Re-exporting with:")
else:
logging.info("Exporting with:")
logging.info(f"inputs: {self.example_inputs}")
logging.info(f"kwargs: {self.example_kwarg_inputs}")
logging.info(f"dynamic shapes: {dynamic_shape}")
exported_module = export_for_training(
self.model,
self.model if not module else module,
self.example_inputs,
kwargs=self.example_kwarg_inputs,
dynamic_shapes=dynamic_shape,
)
# `Module`.
self.pre_autograd_exported_program = exported_module
self.exported_program = exported_module
# Need to store the graph module to record transformation passes.
# Persisting those changes back to the ExportedProgram will require
# an additional export().
self.pre_autograd_graph_module = exported_module.module()
if hasattr(self.args, "export_only") and self.args.export_only:
torch.export.save(exported_module, self.args.output_name)
@@ -382,7 +403,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
if self.pre_autograd_graph_module is None:
if self.exported_program is None:
# Run export() if it didn't run
self.export()

@@ -394,9 +415,12 @@
return_value=False,
)

# Prior to export, persist the changes to the pre-autograd
# graph module back to the source-of-truth ExportedProgram.
self.export(self.pre_autograd_graph_module)
Contributor:
I think we should keep exported_program up-to-date. Thus we shouldn't do this here, but rather wherever we extract graph_module and apply any transformations. Then we would not keep self.pre_autograd_graph_module at all; the only source of truth would be exported_program.
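
A minimal sketch of that suggestion, under stated assumptions: `LLMEdgeManagerSketch`, `transform()`, and the `passes` argument are hypothetical names that are not part of this PR; only `export_for_training` and `exported_program` come from the diff. The idea is that any transformation re-exports immediately, so `exported_program` stays the single source of truth and no `pre_autograd_graph_module` attribute is kept.

```python
from typing import Callable, List, Optional

import torch
from torch.export import ExportedProgram, export_for_training


class LLMEdgeManagerSketch:
    """Hypothetical slice of LLMEdgeManager illustrating the suggestion above."""

    def __init__(self, model: torch.nn.Module, example_inputs, dynamic_shapes=None):
        self.model = model
        self.example_inputs = example_inputs
        self.dynamic_shapes = dynamic_shapes
        # Single source of truth; no self.pre_autograd_graph_module.
        self.exported_program: Optional[ExportedProgram] = None

    def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManagerSketch":
        # Pre-autograd export, as in this PR's export().
        self.exported_program = export_for_training(
            module if module is not None else self.model,
            self.example_inputs,
            dynamic_shapes=self.dynamic_shapes,
        )
        return self

    def transform(
        self, passes: List[Callable[[torch.nn.Module], None]]
    ) -> "LLMEdgeManagerSketch":
        # Extract the graph module, run the passes (assumed to mutate it in
        # place), and re-export right away so exported_program never goes stale.
        graph_module = self.exported_program.module()
        for p in passes:
            p(graph_module)
        return self.export(graph_module)
```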

with override_export_behaviour:
self.edge_manager = export_to_edge(
self.pre_autograd_graph_module, # pyre-fixme[6]
self.exported_program.module(), # pyre-fixme[6]
self.example_inputs,
example_kwarg_inputs=self.example_kwarg_inputs,
dynamic_shapes=dynamic_shape,
@@ -441,9 +465,14 @@ def to_edge_transform_and_lower(
) -> "LLMEdgeManager":
if partitioners is None:
logging.info("No partitioner provided, skipping backend lowering...")

# Prior to export, persist the changes to the pre-autograd
# graph module back to the source-of-truth ExportedProgram.
self.export(self.pre_autograd_graph_module)

edge_config = self._get_edge_config()
self.edge_manager = to_edge_transform_and_lower(
self.pre_autograd_exported_program,
self.exported_program,
partitioner=partitioners,
compile_config=edge_config,
constant_methods=self.metadata,
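
For context, a rough caller-side sketch of the flow this diff describes (pre-autograd export, graph transforms, then edge lowering). `MyLLM` and `run_my_passes` are illustrative assumptions, and LLMEdgeManager's full constructor signature is not shown in the diff; only the method and attribute names are taken from the code above.

```python
import torch


def run_my_passes(graph_module: torch.nn.Module) -> None:
    # Placeholder for whatever graph transformations the caller applies.
    pass


my_llm = MyLLM()  # hypothetical eager torch.nn.Module for the model

manager = LLMEdgeManager(
    model=my_llm,
    modelname="my_llm",
    max_seq_len=2048,
    dynamic_shapes=None,
    # Other constructor arguments (example inputs, etc.) are not shown in
    # this diff and are omitted here.
)

# 1. Pre-autograd export via export_for_training(); populates
#    manager.exported_program and manager.pre_autograd_graph_module.
manager.export()

# 2. Run transform passes on the pre-autograd graph module. The changes are
#    persisted back into the ExportedProgram by the re-export that
#    export_to_edge() / to_edge_transform_and_lower() perform internally.
run_my_passes(manager.pre_autograd_graph_module)

# 3. Full torch.export() plus lowering to the edge dialect happens here; with
#    no partitioners, backend lowering is skipped per the diff above.
manager.to_edge_transform_and_lower(partitioners=None)
```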