Commit 7ceacde
Update on "Save foundation weights separately"
This diff:

1. Introduces SerializationConfig to llm_config. Currently, this allows the user to save the foundation weights in a separate file, which is mainly useful for the LoRA case.
2. Adds a pass to tag foundation (non-LoRA) weights. This runs at the top level (export_llama_lib). The tags are preserved through run_decompositions and other passes, and do not affect functionality.
3. Tags are read when placing constants into the named_data_store.
4. Tagged weights are serialized to a separate file.

Notes:

1. Adding tags to node.meta['custom']['blah'] means that they will not be discarded by run_decompositions.
2. Adding tags to the lifted model (ep.graph_module) would require the EP to check is_param_node for XNNPACK constants. Instead, we add tags to the unlifted model (ep.module()), so we do not need to go through a re-export to get the EP.
3. Not an issue for this diff, as llama doesn't have any higher-order ops. Adding tags to models with higher-order ops is problematic due to nested submodules.

Differential Revision: [D79181064](https://our.internmc.facebook.com/intern/diff/D79181064/)

[ghstack-poisoned]
2 parents: 36be399, d45f5aa
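The tagging scheme in notes (1) and (2) can be pictured with a short sketch. This is an illustrative reconstruction, not the pass from the diff: the function name and the "external_file" key are hypothetical; only gen_tag_fn and the node.meta['custom'] convention come from the commit.

```python
import torch


def tag_foundation_weights(module: torch.fx.GraphModule, gen_tag_fn) -> None:
    # Walk the *unlifted* module (ep.module()), where parameters still
    # appear as get_attr nodes, so no re-export is needed to get the EP.
    for node in module.graph.nodes:
        if node.op != "get_attr":
            continue
        tag = gen_tag_fn(node)
        if tag is None:
            continue  # e.g. LoRA weights stay inside the PTE
        # Entries under node.meta["custom"] survive run_decompositions;
        # "external_file" is a hypothetical key name for this sketch.
        node.meta.setdefault("custom", {})["external_file"] = tag
```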

File tree: 3 files changed (+8, −23 lines)

.ci/scripts/test_llama_lora.sh

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
     backend.xnnpack.enabled=true \
     backend.xnnpack.extended_ops=true \
     export.output_name="${MODEL_SEPARATE}.pte" \
-    serialization.foundation_weights_file="${MODEL_SEPARATE}.ptd"
+    export.foundation_weights_file="${MODEL_SEPARATE}.ptd"

 # Run llama runner.
 NOW=$(date +"%H:%M:%S")

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 2 deletions
@@ -1078,9 +1078,9 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True

     if llm_config.backend.xnnpack.enabled:
-        if llm_config.serialization.foundation_weights_file is not None:
+        if llm_config.export.foundation_weights_file is not None:
             gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: (
-                llm_config.serialization.foundation_weights_file
+                llm_config.export.foundation_weights_file
                 if "lora" not in x.name
                 else None
             )
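To make the lambda's behavior concrete, here is a tiny self-contained check; the node names and the .ptd filename are made up:

```python
# Stand-in node with just a .name attribute, enough to exercise the lambda.
class _Node:
    def __init__(self, name: str) -> None:
        self.name = name


foundation_file = "llama.ptd"  # hypothetical value of export.foundation_weights_file
gen_tag_fn = lambda x: foundation_file if "lora" not in x.name else None

# Foundation weights are tagged with the external file...
assert gen_tag_fn(_Node("layers_0_attention_wq_weight")) == "llama.ptd"
# ...while LoRA adapter weights return None and stay inside the PTE.
assert gen_tag_fn(_Node("lora_a_weight")) is None
```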

extension/llm/export/config/llm_config.py

Lines changed: 5 additions & 20 deletions
@@ -211,6 +211,9 @@ class ExportConfig:
         so_library: Shared library to specify custom quantized operators.
         export_only: Whether to stop right after torch.export() and
             just save the exported .pt2 graph file.
+        foundation_weights_file: configure the foundation weights of a model
+            to be placed in a separate file, external to the PTE. Pass the
+            intended file name here.
     """

     max_seq_length: int = 128
@@ -219,6 +222,7 @@ class ExportConfig:
     output_name: Optional[str] = None
     so_library: Optional[str] = None
     export_only: bool = False
+    foundation_weights_file: Optional[str] = None

     def __post_init__(self):
         if self.max_context_length < self.max_seq_length:
@@ -227,20 +231,6 @@ def __post_init__(self):
             )


-@dataclass
-class SerializationConfig:
-    """
-    Configures properties relevant to the serialization process.
-
-    Attributes:
-        foundation_weights_file: configure the foundation weights of a model
-            to be placed in a separate file, external to the PTE. Pass the
-            intended file name here.
-    """
-
-    foundation_weights_file: Optional[str] = None
-
-
 ################################################################################
 ################################# DebugConfig ##################################
 ################################################################################
@@ -480,7 +470,6 @@ class LlmConfig:
     base: BaseConfig = field(default_factory=BaseConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
     export: ExportConfig = field(default_factory=ExportConfig)
-    serialization: SerializationConfig = field(default_factory=SerializationConfig)
     debug: DebugConfig = field(default_factory=DebugConfig)
     quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
     backend: BackendConfig = field(default_factory=BackendConfig)
@@ -560,12 +549,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
             llm_config.export.so_library = args.so_library
         if hasattr(args, "export_only"):
             llm_config.export.export_only = args.export_only
-
-        # SerializationConfig
         if hasattr(args, "foundation_weights_file"):
-            llm_config.serialization.foundation_weights_file = (
-                args.foundation_weights_file
-            )
+            llm_config.export.foundation_weights_file = args.foundation_weights_file

         # QuantizationConfig
         if hasattr(args, "quantization_mode"):
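With SerializationConfig removed, the option now hangs off ExportConfig. A minimal usage sketch; the import path and the .ptd filename are assumptions, not taken from the diff:

```python
from executorch.extension.llm.export.config.llm_config import LlmConfig

llm_config = LlmConfig()
# Foundation (non-LoRA) weights will be written to this .ptd file,
# separate from the .pte program.
llm_config.export.foundation_weights_file = "llama_foundation.ptd"

# Equivalent hydra-style CLI override, as in the CI script above:
#   export.foundation_weights_file="llama_foundation.ptd"
```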
