Commit 39caaa1
Save foundation weights separately

Pull Request resolved: #13161

This diff:
1. Introduces SerializationConfig to llm_config. Currently this allows the user to save the foundation weights to a separate file, which is mainly useful for the LoRA case.
2. Adds a pass to tag foundation (non-LoRA) weights. This runs at the top level (export_llama_lib). The tags are preserved through run_decompositions and other passes, and do not affect functionality.
3. Tags are read when placing constants into the named_data_store.
4. Tagged weights are serialized to a separate file.

Notes:
1. Adding tags to node.meta['custom']['blah'] means they will not be discarded by run_decompositions.
2. Adding tags to the lifted model (ep.graph_module) requires the EP to check is_param_node for XNNPACK constants. Instead, we add tags to the unlifted model (ep.module()), so we do not need to go through a re-export to get the EP.
3. Not an issue for this diff, as llama doesn't have any higher-order ops; adding tags to models with higher-order ops is problematic due to nested submodules.

ghstack-source-id: 301151822
@exported-using-ghexport

Differential Revision: [D79181064](https://our.internmc.facebook.com/intern/diff/D79181064/)
1 parent bbb913b commit 39caaa1
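
For orientation, here is a minimal sketch of the mechanism this diff relies on: in an unlifted module, constants appear as get_attr nodes, and a tag stored under node.meta["custom"] survives later passes (see note 1 above). The TinyLora module and the "foundation.ptd" file name below are illustrative, not part of this diff; only the meta["custom"]["delegate_constant_tag"] key comes from it.

```python
# Minimal sketch of the tagging mechanism. TinyLora and "foundation.ptd"
# are hypothetical; the meta["custom"]["delegate_constant_tag"] key is
# the one introduced in this diff.
import torch
from torch.export import export


class TinyLora(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4, bias=False)  # foundation weight
        self.lora_a = torch.nn.Parameter(torch.zeros(4, 2))  # adapter weight
        self.lora_b = torch.nn.Parameter(torch.zeros(2, 4))  # adapter weight

    def forward(self, x):
        return self.linear(x) + (x @ self.lora_a) @ self.lora_b


ep = export(TinyLora(), (torch.randn(1, 4),))
gm = ep.module()  # unlifted module: constants show up as get_attr nodes

for node in gm.graph.nodes:
    if node.op == "get_attr":
        # Foundation (non-LoRA) weights get tagged with the external file name.
        tag = "foundation.ptd" if "lora" not in node.name else None
        node.meta.setdefault("custom", {})
        node.meta["custom"]["delegate_constant_tag"] = tag
        print(node.name, "->", tag)
```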

File tree

7 files changed: +120 −18 lines

.ci/scripts/test_llama_lora.sh

Lines changed: 51 additions & 14 deletions
```diff
@@ -48,8 +48,17 @@ DOWNLOADED_PATH=$(
   --model_id "${HF_MODEL_REPO}" \
   --files "adapter_config.json" "adapter_model.pt" "consolidated.00.pth" "params.json" "tokenizer.model"
 )
-EXPORTED_MODEL_NAME="llama_3_2_1B_lora.pte"
-# Export model.
+# Build llama runner.
+cmake_install_executorch_libraries
+cmake_build_llama_runner
+
+# Constants.
+RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
+PROMPT="What happens if you eat watermelon seeds?"
+EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
+
+# Export LoRA PTE file.
+MODEL_NAME="llama_3_2_1B_lora"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
   base.params="${DOWNLOADED_PATH}/params.json" \
@@ -61,36 +70,64 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   model.dtype_override="fp32" \
   backend.xnnpack.enabled=true \
   backend.xnnpack.extended_ops=true \
-  export.output_name="${EXPORTED_MODEL_NAME}"
-
-# Build llama runner.
-cmake_install_executorch_libraries
-cmake_build_llama_runner
+  export.output_name="${MODEL_NAME}.pte"
 
-PROMPT="What happens if you eat watermelon seeds?"
 # Run llama runner
-RUNTIME_ARGS="--model_path=${EXPORTED_MODEL_NAME} --tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
-
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_NAME}.pte --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
 RESULT=$(cat result.txt)
-EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
-
 if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT}"
+  # Do not clean up files if test passes, as they're re-used in the next test.
   echo "Success"
-  cleanup_files
 else
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT}"
   echo "Failure; results not the same"
+  cleanup_files
+  exit 1
+fi
 
+# Export LoRA PTE, PTD file.
+MODEL_SEPARATE="${MODEL_NAME}_separate"
+$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
+  base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
+  base.params="${DOWNLOADED_PATH}/params.json" \
+  base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \
+  base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \
+  base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
+  model.use_kv_cache=true \
+  model.use_sdpa_with_kv_cache=true \
+  model.dtype_override="fp32" \
+  backend.xnnpack.enabled=true \
+  backend.xnnpack.extended_ops=true \
+  export.output_name="${MODEL_SEPARATE}.pte" \
+  serialization.foundation_weights_file="${MODEL_SEPARATE}.ptd"
+
+# Run llama runner.
+NOW=$(date +"%H:%M:%S")
+echo "Starting to run llama runner at ${NOW}"
+# shellcheck source=/dev/null
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+NOW=$(date +"%H:%M:%S")
+echo "Finished at ${NOW}"
+
+RESULT2=$(cat result2.txt)
+if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT2}"
+  echo "Success"
+  cleanup_files
+else
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT2}"
+  echo "Failure; results not the same"
   cleanup_files
   exit 1
 fi
```

backends/xnnpack/operators/node_visitor.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -621,8 +621,12 @@ def get_serialized_buffer_index(
             ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
         )
 
-        external_tag = tensor.meta.get("delegate_constant_tag", None)
+        custom_meta = tensor.meta.get("custom", None)
+        external_tag = (
+            custom_meta.get("delegate_constant_tag", None) if custom_meta else None
+        )
         if external_tag is not None:
+            external_tag = custom_meta.get("delegate_constant_tag", None)
             logging.info(
                 f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
             )
```
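
The nesting matters here: per note 1 in the commit message, tags survive run_decompositions only when stored under node.meta["custom"], so the visitor has to look one level down. A condensed sketch of the lookup (with `tensor` an fx node carrying constant data, as in the diff):

```python
# Condensed form of the lookup above; either level may be absent, so guard both.
custom_meta = tensor.meta.get("custom", None)
external_tag = custom_meta.get("delegate_constant_tag", None) if custom_meta else None
```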

examples/models/llama/TARGETS

Lines changed: 1 addition & 1 deletion
```diff
@@ -153,10 +153,10 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/extension/llm/export/config:llm_config",
         "//executorch/backends/vulkan/_passes:vulkan_passes",
+        "//executorch/exir/passes:external_constants_pass",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
-        "//executorch/exir/passes:init_mutable_pass",
         "//executorch/extension/llm/custom_ops:custom_ops_aot_py",
         "//executorch/extension/llm/export:export_lib",
         # one definition has to be included in the user of the libarary
```

examples/models/llama/export_llama_lib.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -1078,6 +1078,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
+        if llm_config.serialization.foundation_weights_file is not None:
+            gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: (
+                llm_config.serialization.foundation_weights_file
+                if "lora" not in x.name
+                else None
+            )
+
+            from executorch.exir.passes.external_constants_pass import (
+                delegate_external_constants_pass_unlifted,
+            )
+
+            delegate_external_constants_pass_unlifted(
+                gm=builder_exported.pre_autograd_graph_module,
+                gen_tag_fn=gen_tag_fn,
+            )
+
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
```
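
The tag function keys purely off node names: anything without "lora" in its name is treated as a foundation weight and routed to the external file, while adapter weights (tag None) stay in the PTE. The same rule applied to plain strings, with hypothetical node names:

```python
# gen_tag_fn's rule, applied to strings instead of fx Nodes; the node names
# and "model.ptd" are hypothetical examples.
tag = lambda name: "model.ptd" if "lora" not in name else None

assert tag("layers_0_attention_wq_weight") == "model.ptd"  # foundation -> .ptd
assert tag("layers_0_attention_wq_lora_a_weight") is None  # adapter stays in .pte
```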

exir/passes/external_constants_pass.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -113,6 +113,28 @@ def delegate_external_constants_pass(
         for node in module.graph.nodes:
             if node.op == "placeholder" and is_param_node(ep, node):
                 if gen_tag_fn is not None:
-                    node.meta["delegate_constant_tag"] = gen_tag_fn(node)
+                    node.meta.setdefault("custom", {})
+                    node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
+                    mutated = True
+    return PassResult(gm, mutated)
+
+
+# Note: this pass must be run on an unlifted graph, e.g. ep.module(),
+# and not on a lifted graph, e.g. ep.graph_module.
+# This is using 'get_attr' to tag constants, which only appears in
+# unlifted graphs.
+def delegate_external_constants_pass_unlifted(
+    gm: GraphModule,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
+) -> PassResult:
+    mutated = False
+    for module in gm.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+        for node in module.graph.nodes:
+            if node.op == "get_attr":
+                if gen_tag_fn is not None:
+                    node.meta.setdefault("custom", {})
+                    node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
                     mutated = True
     return PassResult(gm, mutated)
```
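
A usage sketch for the new pass, continuing the TinyLora example near the top of this page ("foundation.ptd" is again an illustrative name; the return value is a torch.fx PassResult):

```python
# Usage sketch (assumes `ep` from the TinyLora example above).
from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)

gm = ep.module()  # must be unlifted: the pass only tags get_attr nodes
result = delegate_external_constants_pass_unlifted(
    gm=gm,
    gen_tag_fn=lambda n: "foundation.ptd" if "lora" not in n.name else None,
)
print("graph mutated:", result.modified)
```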

extension/llm/export/config/llm_config.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -227,6 +227,20 @@ def __post_init__(self):
             )
 
 
+@dataclass
+class SerializationConfig:
+    """
+    Configures properties relevant to the serialization process.
+
+    Attributes:
+        foundation_weights_file: configure the foundation weights of a model
+            to be placed in a separate file, external to the PTE. Pass the
+            intended file name here.
+    """
+
+    foundation_weights_file: Optional[str] = None
+
+
 ################################################################################
 ################################# DebugConfig ##################################
 ################################################################################
@@ -466,6 +480,7 @@ class LlmConfig:
     base: BaseConfig = field(default_factory=BaseConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
     export: ExportConfig = field(default_factory=ExportConfig)
+    serialization: SerializationConfig = field(default_factory=SerializationConfig)
     debug: DebugConfig = field(default_factory=DebugConfig)
     quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
     backend: BackendConfig = field(default_factory=BackendConfig)
@@ -546,6 +561,12 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
         if hasattr(args, "export_only"):
             llm_config.export.export_only = args.export_only
 
+        # SerializationConfig
+        if hasattr(args, "foundation_weights_file"):
+            llm_config.serialization.foundation_weights_file = (
+                args.foundation_weights_file
+            )
+
         # QuantizationConfig
         if hasattr(args, "quantization_mode"):
             llm_config.quantization.qmode = args.quantization_mode
```
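
For callers that build the config programmatically rather than through CLI args or hydra overrides, the new field hangs off the serialization sub-config. A sketch, assuming the in-tree import path:

```python
# Sketch of setting the new field in code (import path mirrors this repo's
# layout and may differ in your install; the .ptd file name is illustrative).
from executorch.extension.llm.export.config.llm_config import LlmConfig

llm_config = LlmConfig()
llm_config.serialization.foundation_weights_file = "llama_foundation.ptd"
```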

runtime/executor/merged_data_map.h

Lines changed: 3 additions & 1 deletion
```diff
@@ -37,8 +37,10 @@ class MergedDataMap final : public NamedDataMap {
   // Check for duplicate keys.
   for (uint32_t k = 0; k < first->get_num_keys().get(); k++) {
     const auto key = first->get_key(k).get();
+    const auto error = second->get_tensor_layout(key).error();
+    // TODO(lfq): add API to check if key exists.
     ET_CHECK_OR_RETURN_ERROR(
-        second->get_tensor_layout(key).error() == Error::NotFound,
+        error == Error::NotFound || error == Error::NotImplemented,
         InvalidArgument,
         "Duplicate key %s.",
         key);
```
