
Commit dba2a56

Export lora weights to sep file
Differential Revision: [D83777195](https://our.internmc.facebook.com/intern/diff/D83777195/)
ghstack-source-id: 313704499
Pull Request resolved: #14756
1 parent b100c95 commit dba2a56

File tree

2 files changed: +19 -8 lines changed


examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 5 deletions
@@ -1089,11 +1089,18 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
 
     if llm_config.backend.xnnpack.enabled:
         if llm_config.export.foundation_weights_file is not None:
-            gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
-                llm_config.export.foundation_weights_file
-                if "lora" not in x.name
-                else None
-            )
+            if llm_config.export.lora_weights_file is not None:
+                gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
+                    llm_config.export.foundation_weights_file
+                    if "lora" not in x.name
+                    else None
+                )
+            else:
+                gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
+                    llm_config.export.foundation_weights_file
+                    if "lora" not in x.name
+                    else llm_config.export.lora_weights_file
+                )
 
             from executorch.exir.passes.external_constants_pass import (
                 delegate_external_constants_pass_unlifted,
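The tag function above decides, node by node, which external file a constant belongs to. Below is a minimal standalone sketch of that routing rule: it operates on plain node-name strings rather than torch.fx.Node objects, and the file names and node names are illustrative assumptions. It is not the executorch pass itself, only an illustration of the same predicate.

# Sketch only: mirrors the gen_tag_fn predicate from the diff above.
# "foundation.ptd" and "lora.ptd" are assumed placeholder file names.
from typing import Optional

foundation_weights_file = "foundation.ptd"
lora_weights_file: Optional[str] = "lora.ptd"

def gen_tag(node_name: str) -> Optional[str]:
    # Weights whose node name mentions "lora" are routed to the lora file
    # (or left untagged, i.e. kept in the PTE, when no lora file is given);
    # everything else is routed to the foundation weights file.
    return foundation_weights_file if "lora" not in node_name else lora_weights_file

print(gen_tag("layers_0_attention_wq_weight"))         # foundation.ptd
print(gen_tag("layers_0_attention_wq_lora_a_weight"))  # lora.ptd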

extension/llm/export/config/llm_config.py

Lines changed: 7 additions & 3 deletions
@@ -215,9 +215,10 @@ class ExportConfig:
         so_library: Shared library to specify custom quantized operators.
         export_only: Whether to stop right after torch.export() and
             just save the exported .pt2 graph file.
-        foundation_weights_file: configure the foundation weights of a model
-            to be placed in a separate file, external to the PTE. Pass the
-            intended file name here.
+        foundation_weights_file: place the foundation weights of the model into
+            a separate file, external to the PTE. Pass the file name here.
+        lora_weights_file: place the lora weights of the model into a
+            separate file, external to the PTE. Pass the file name here.
     """
 
     max_seq_length: int = 128
@@ -227,6 +228,7 @@ class ExportConfig:
     so_library: Optional[str] = None
     export_only: bool = False
     foundation_weights_file: Optional[str] = None
+    lora_weights_file: Optional[str] = None
 
     def __post_init__(self):
         if self.max_context_length < self.max_seq_length:
@@ -572,6 +574,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
             llm_config.export.export_only = args.export_only
         if hasattr(args, "foundation_weights_file"):
             llm_config.export.foundation_weights_file = args.foundation_weights_file
+        if hasattr(args, "lora_weights_file"):
+            llm_config.export.lora_weights_file = args.lora_weights_file
 
         # QuantizationConfig
         if hasattr(args, "quantization_mode"):
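For orientation, here is a hedged sketch of how the two fields are expected to flow from CLI arguments into the export config. It mirrors the hasattr() copy pattern in from_args() above using plain argparse; the flag names and file names are assumptions inferred from the attribute names in the diff, and no executorch APIs are called.

# Sketch only: reproduces the hasattr() copy pattern from from_args() above.
# Flag names are assumed from the config attribute names; file names are placeholders.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--foundation_weights_file", type=str, default=None)
parser.add_argument("--lora_weights_file", type=str, default=None)
args = parser.parse_args(
    ["--foundation_weights_file", "foundation.ptd", "--lora_weights_file", "lora.ptd"]
)

export_cfg = {}  # stand-in for llm_config.export
if hasattr(args, "foundation_weights_file"):
    export_cfg["foundation_weights_file"] = args.foundation_weights_file
if hasattr(args, "lora_weights_file"):
    export_cfg["lora_weights_file"] = args.lora_weights_file

print(export_cfg)
# {'foundation_weights_file': 'foundation.ptd', 'lora_weights_file': 'lora.ptd'}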
