Commit a516927

Save foundation weights separately
This diff:

1. Introduces SerializationConfig to llm_config. Currently, this allows the user to save the foundation weights in a separate file, which is mainly useful for the LoRA case.
2. Adds a pass to tag foundation (non-LoRA) weights. This happens at the top level (export_llama_lib). The tags are preserved through run_decompositions and other passes, and do not affect functionality.
3. Tags are read when placing constants into the named_data_store.
4. Tagged weights are serialized to a separate file.

Notes:

1. Adding tags to node.meta['custom']['blah'] means that they will not be discarded by run_decompositions.
2. Adding tags to the lifted model (ep.graph_module) requires the EP to check is_param_node for XNNPACK constants. Instead, we add tags to the unlifted model (ep.module()), so we do not need to go through a re-export to get the EP.
3. Not an issue for this diff, as llama doesn't have any higher-order ops. Adding tags to models with higher-order ops is problematic due to nested submodules.

Differential Revision: [D79181064](https://our.internmc.facebook.com/intern/diff/D79181064/)
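
For orientation before the per-file diffs, here is a minimal sketch of driving the new option programmatically. The import path is inferred from the llm_config Buck target in this diff, and the .ptd file name is illustrative; treat both as assumptions.

```python
# Sketch: ask the export pipeline to place foundation (non-LoRA) weights
# in a separate .ptd file. LlmConfig and SerializationConfig are introduced
# in this diff; the import path below may differ in your checkout.
from executorch.extension.llm.export.config.llm_config import LlmConfig

llm_config = LlmConfig()
llm_config.backend.xnnpack.enabled = True
# Non-LoRA constants get tagged with this file name and are serialized
# outside the PTE; LoRA weights stay inside the PTE.
llm_config.serialization.foundation_weights_file = "llama_3_2_1B_lora.ptd"
```

The CI script below exercises the same flow through hydra-style CLI overrides.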
1 parent 6bc312a

7 files changed: +118 -17 lines

.ci/scripts/test_llama_lora.sh

Lines changed: 51 additions & 14 deletions

```diff
@@ -48,8 +48,17 @@ DOWNLOADED_PATH=$(
   --model_id "${HF_MODEL_REPO}" \
   --files "adapter_config.json" "adapter_model.pt" "consolidated.00.pth" "params.json" "tokenizer.model"
 )
-EXPORTED_MODEL_NAME="llama_3_2_1B_lora.pte"
-# Export model.
+# Build llama runner.
+cmake_install_executorch_libraries
+cmake_build_llama_runner
+
+# Constants.
+RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
+PROMPT="What happens if you eat watermelon seeds?"
+EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
+
+# Export LoRA PTE file.
+MODEL_NAME="llama_3_2_1B_lora"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
   base.params="${DOWNLOADED_PATH}/params.json" \
@@ -61,26 +70,17 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   model.dtype_override="fp32" \
   backend.xnnpack.enabled=true \
   backend.xnnpack.extended_ops=true \
-  export.output_name="${EXPORTED_MODEL_NAME}"
-
-# Build llama runner.
-cmake_install_executorch_libraries
-cmake_build_llama_runner
+  export.output_name="${MODEL_NAME}.pte"
 
-PROMPT="What happens if you eat watermelon seeds?"
 # Run llama runner
-RUNTIME_ARGS="--model_path=${EXPORTED_MODEL_NAME} --tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
-
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_NAME}.pte --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
 RESULT=$(cat result.txt)
-EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
-
 if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT}"
@@ -90,7 +90,44 @@ else
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT}"
   echo "Failure; results not the same"
+  cleanup_files
+  exit 1
+fi
+
+# Export LoRA PTE, PTD file.
+MODEL_SEPARATE="${MODEL_NAME}_separate"
+$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
+  base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
+  base.params="${DOWNLOADED_PATH}/params.json" \
+  base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \
+  base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \
+  base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
+  model.use_kv_cache=true \
+  model.use_sdpa_with_kv_cache=true \
+  model.dtype_override="fp32" \
+  backend.xnnpack.enabled=true \
+  backend.xnnpack.extended_ops=true \
+  export.output_name="${MODEL_SEPARATE}.pte" \
+  serialization.foundation_weights_file="${MODEL_SEPARATE}.ptd"
 
+# Run llama runner.
+NOW=$(date +"%H:%M:%S")
+echo "Starting to run llama runner at ${NOW}"
+# shellcheck source=/dev/null
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+NOW=$(date +"%H:%M:%S")
+echo "Finished at ${NOW}"
+
+RESULT2=$(cat result2.txt)
+if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT2}"
+  echo "Success"
+  cleanup_files
+else
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT2}"
+  echo "Failure; results not the same"
   cleanup_files
   exit 1
 fi
```

backends/xnnpack/operators/node_visitor.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -621,8 +621,10 @@ def get_serialized_buffer_index(
             ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
         )
 
-        external_tag = tensor.meta.get("delegate_constant_tag", None)
+        custom_meta = tensor.meta.get("custom", None)
+        external_tag = custom_meta.get("delegate_constant_tag", None) if custom_meta else None
         if external_tag is not None:
             logging.info(
                 f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
             )
```
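
Because the tag now lives under the nested node.meta["custom"] dict (note 1 in the commit message), the visitor does a two-step lookup. A standalone sketch of the pattern, with illustrative inputs:

```python
from typing import Any, Dict, Optional


def read_delegate_tag(meta: Dict[str, Any]) -> Optional[str]:
    # Entries under meta["custom"] survive run_decompositions, which is
    # why the tag is nested rather than a top-level meta key.
    custom_meta = meta.get("custom", None)
    return custom_meta.get("delegate_constant_tag", None) if custom_meta else None


assert read_delegate_tag({"custom": {"delegate_constant_tag": "model.ptd"}}) == "model.ptd"
assert read_delegate_tag({}) is None  # untagged constants stay in the PTE
```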

examples/models/llama/TARGETS

Lines changed: 1 addition & 1 deletion

```diff
@@ -153,10 +153,10 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/extension/llm/export/config:llm_config",
         "//executorch/backends/vulkan/_passes:vulkan_passes",
+        "//executorch/exir/passes:external_constants_pass",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
-        "//executorch/exir/passes:init_mutable_pass",
         "//executorch/extension/llm/custom_ops:custom_ops_aot_py",
         "//executorch/extension/llm/export:export_lib",
         # one definition has to be included in the user of the libarary
```

examples/models/llama/export_llama_lib.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -1078,6 +1078,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
+        if llm_config.serialization.foundation_weights_file is not None:
+            gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: (
+                llm_config.serialization.foundation_weights_file
+                if "lora" not in x.name
+                else None
+            )
+
+            from executorch.exir.passes.external_constants_pass import (
+                delegate_external_constants_pass_unlifted,
+            )
+
+            delegate_external_constants_pass_unlifted(
+                gm=builder_exported.pre_autograd_graph_module,
+                gen_tag_fn=gen_tag_fn,
+            )
+
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
```
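
The routing rule above keys purely on the substring "lora" in the FX node name: any constant that does not mention LoRA is treated as a foundation weight and tagged with the external file name. A small illustration of the same rule (the node names are hypothetical):

```python
from typing import Optional

FOUNDATION_FILE = "llama_3_2_1B_lora_separate.ptd"


def gen_tag_fn(name: str) -> Optional[str]:
    # Same rule as the lambda above: foundation weights -> external file,
    # LoRA adapter weights -> None (kept inside the PTE).
    return FOUNDATION_FILE if "lora" not in name else None


assert gen_tag_fn("layers_0_attention_wq_weight") == FOUNDATION_FILE  # hypothetical name
assert gen_tag_fn("layers_0_attention_wq_lora_a_weight") is None  # hypothetical name
```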

exir/passes/external_constants_pass.py

Lines changed: 23 additions & 1 deletion

```diff
@@ -113,6 +113,28 @@ def delegate_external_constants_pass(
         for node in module.graph.nodes:
             if node.op == "placeholder" and is_param_node(ep, node):
                 if gen_tag_fn is not None:
-                    node.meta["delegate_constant_tag"] = gen_tag_fn(node)
+                    node.meta.setdefault("custom", {})
+                    node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
+                    mutated = True
+    return PassResult(gm, mutated)
+
+
+# Note: this pass must be run on an unlifted graph, e.g. ep.module(),
+# and not on a lifted graph, e.g. ep.graph_module.
+# This is using 'get_attr' to tag constants, which only appears in
+# unlifted graphs.
+def delegate_external_constants_pass_unlifted(
+    gm: GraphModule,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
+) -> PassResult:
+    mutated = False
+    for module in gm.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+        for node in module.graph.nodes:
+            if node.op == "get_attr":
+                if gen_tag_fn is not None:
+                    node.meta.setdefault("custom", {})
+                    node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
                     mutated = True
     return PassResult(gm, mutated)
```
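
A minimal sketch of calling the new pass on a toy module, per the note above: it runs on the unlifted ep.module(), where constants surface as get_attr nodes. The module, export call, and .ptd name are illustrative:

```python
import torch

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)


class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


ep = torch.export.export(Tiny(), (torch.randn(1, 4),))
# ep.module() is the unlifted module: weights reappear as get_attr nodes,
# which is what this pass keys on (ep.graph_module would show placeholders).
result = delegate_external_constants_pass_unlifted(
    gm=ep.module(),
    gen_tag_fn=lambda node: "weights.ptd",  # route every constant to one file
)
for node in result.graph_module.graph.nodes:
    if node.op == "get_attr":
        print(node.name, "->", node.meta["custom"]["delegate_constant_tag"])
```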

extension/llm/export/builder.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -541,6 +541,9 @@ def save_to_pte(self, output_name: str) -> None:
         filename = save_pte_program(self.export_program, output_name, self.output_dir)
         self._saved_pte_filename = filename
 
+        # Write any tensor data tagged for separate .ptd files.
+        self.export_program.write_tensor_data_to_file(self.output_dir)
+
     def get_saved_pte_filename(self) -> Optional[str]:
         """
         Return the filename of the most recenet saved .pte file. Return None if the model is not saved.
```

extension/llm/export/config/llm_config.py

Lines changed: 21 additions & 0 deletions

```diff
@@ -227,6 +227,20 @@ def __post_init__(self):
         )
 
 
+@dataclass
+class SerializationConfig:
+    """
+    Configures properties relevant to the serialization process.
+
+    Attributes:
+        foundation_weights_file: configure the foundation weights of a model
+            to be placed in a separate file, external to the PTE. Pass the
+            intended file name here.
+    """
+
+    foundation_weights_file: Optional[str] = None
+
+
 ################################################################################
 ################################# DebugConfig ##################################
 ################################################################################
@@ -466,6 +480,7 @@ class LlmConfig:
     base: BaseConfig = field(default_factory=BaseConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
     export: ExportConfig = field(default_factory=ExportConfig)
+    serialization: SerializationConfig = field(default_factory=SerializationConfig)
     debug: DebugConfig = field(default_factory=DebugConfig)
     quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
     backend: BackendConfig = field(default_factory=BackendConfig)
@@ -546,6 +561,12 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
         if hasattr(args, "export_only"):
             llm_config.export.export_only = args.export_only
 
+        # SerializationConfig
+        if hasattr(args, "foundation_weights_file"):
+            llm_config.serialization.foundation_weights_file = (
+                args.foundation_weights_file
+            )
+
         # QuantizationConfig
         if hasattr(args, "quantization_mode"):
             llm_config.quantization.qmode = args.quantization_mode
```
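
The new field rides the existing from_args plumbing. A minimal sketch, assuming the rest of from_args is hasattr-guarded like the hunk above (the sparse argparse.Namespace and file name are illustrative):

```python
import argparse

from executorch.extension.llm.export.config.llm_config import LlmConfig

# The flag name mirrors the hasattr check above; fields absent from the
# namespace are simply left at their dataclass defaults.
args = argparse.Namespace(foundation_weights_file="model.ptd")
llm_config = LlmConfig.from_args(args)
assert llm_config.serialization.foundation_weights_file == "model.ptd"
```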
