Skip to content

Commit e042184

Browse files
committed
Update on "Save foundation weights separately"
This diff: 1. Introduces SerializationConfig to llm_config. Currently, this allows user to save the foundation weights in a separate file; majorly useful for lora case. 2. Adds a pass to tag foundation (non-lora) weights. This is at the top-level (export_llama_lib). The tags are preserved through run_decomps/other passes, and do not affect functionality. 3. Tags are read when placing constants into the named_data_store. 4. Tagged weights are serialized to a separate file. Notes 1. Adding tags to node.meta['custom']['blah'] means that they will not be discarded by run_decompositions 2. Adding tags to the lifted model (ep.graph_module) requires the EP to check is_param_node for xnnpack constants. Instead, add tags to the unlifted model (ep.module()), so we do not need to go through a re-export to get the EP. 3. Not an issue for this diff as llama doesn't have any higher order ops. Adding tags to models with higher-order ops is problematic due to nested submodules. Differential Revision: [D79181064](https://our.internmc.facebook.com/intern/diff/D79181064/) [ghstack-poisoned]
2 parents 913a8b7 + 2b9a48e commit e042184

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

.ci/scripts/test_llama_lora.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ RESULT=$(cat result.txt)
8484
if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then
8585
echo "Expected result prefix: ${EXPECTED_PREFIX}"
8686
echo "Actual result: ${RESULT}"
87+
# Do not clean up files if test passes, as they're re-used in the next test.
8788
echo "Success"
88-
cleanup_files
8989
else
9090
echo "Expected result prefix: ${EXPECTED_PREFIX}"
9191
echo "Actual result: ${RESULT}"

backends/xnnpack/operators/node_visitor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,9 @@ def get_serialized_buffer_index(
622622
)
623623

624624
custom_meta = tensor.meta.get("custom", None)
625-
external_tag = custom_meta.get("delegate_constant_tag", None) if custom_meta else None
625+
external_tag = (
626+
custom_meta.get("delegate_constant_tag", None) if custom_meta else None
627+
)
626628
if external_tag is not None:
627629
external_tag = custom_meta.get("delegate_constant_tag", None)
628630
logging.info(

0 commit comments

Comments
 (0)