
Commit 650b32f

pytorchbot and lucylq authored

Save foundation weights separately (#13268)

This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #13161 by @lucylq (please use this as the source of truth for the PR details, comments, and reviews)
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/99/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/99/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/99/orig
@diff-train-skip-merge

Co-authored-by: lucylq <[email protected]>

1 parent 2d4533a, commit 650b32f

File tree

8 files changed: +108 -19 lines changed

.ci/scripts/test_llama_lora.sh

Lines changed: 51 additions & 14 deletions

@@ -48,8 +48,17 @@ DOWNLOADED_PATH=$(
   --model_id "${HF_MODEL_REPO}" \
   --files "adapter_config.json" "adapter_model.pt" "consolidated.00.pth" "params.json" "tokenizer.model"
 )
-EXPORTED_MODEL_NAME="llama_3_2_1B_lora.pte"
-# Export model.
+# Build llama runner.
+cmake_install_executorch_libraries
+cmake_build_llama_runner
+
+# Constants.
+RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
+PROMPT="What happens if you eat watermelon seeds?"
+EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
+
+# Export LoRA PTE file.
+MODEL_NAME="llama_3_2_1B_lora"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
   base.params="${DOWNLOADED_PATH}/params.json" \
@@ -61,36 +70,64 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   model.dtype_override="fp32" \
   backend.xnnpack.enabled=true \
   backend.xnnpack.extended_ops=true \
-  export.output_name="${EXPORTED_MODEL_NAME}"
-
-# Build llama runner.
-cmake_install_executorch_libraries
-cmake_build_llama_runner
+  export.output_name="${MODEL_NAME}.pte"
 
-PROMPT="What happens if you eat watermelon seeds?"
 # Run llama runner
-RUNTIME_ARGS="--model_path=${EXPORTED_MODEL_NAME} --tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
-
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_NAME}.pte --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
 RESULT=$(cat result.txt)
-EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
-
 if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT}"
+  # Do not clean up files if test passes, as they're re-used in the next test.
   echo "Success"
-  cleanup_files
 else
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT}"
   echo "Failure; results not the same"
+  cleanup_files
+  exit 1
+fi
 
+# Export LoRA PTE, PTD file.
+MODEL_SEPARATE="${MODEL_NAME}_separate"
+$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
+  base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
+  base.params="${DOWNLOADED_PATH}/params.json" \
+  base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \
+  base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \
+  base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
+  model.use_kv_cache=true \
+  model.use_sdpa_with_kv_cache=true \
+  model.dtype_override="fp32" \
+  backend.xnnpack.enabled=true \
+  backend.xnnpack.extended_ops=true \
+  export.output_name="${MODEL_SEPARATE}.pte" \
+  export.foundation_weights_file="${MODEL_SEPARATE}.ptd"
+
+# Run llama runner.
+NOW=$(date +"%H:%M:%S")
+echo "Starting to run llama runner at ${NOW}"
+# shellcheck source=/dev/null
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+NOW=$(date +"%H:%M:%S")
+echo "Finished at ${NOW}"
+
+RESULT2=$(cat result2.txt)
+if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT2}"
+  echo "Success"
+  cleanup_files
+else
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT2}"
+  echo "Failure; results not the same"
   cleanup_files
   exit 1
 fi
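A quick, hedged sanity check (not part of the CI script above): once both exports have run, the weight split is visible in the artifact sizes. The file names match the script; exact numbers depend on metadata and padding, so treat this as a sketch rather than an assertion the test makes.

    import os

    full = os.path.getsize("llama_3_2_1B_lora.pte")          # program + weights
    pte = os.path.getsize("llama_3_2_1B_lora_separate.pte")  # program only
    ptd = os.path.getsize("llama_3_2_1B_lora_separate.ptd")  # foundation weights

    # Expect the split PTE to be much smaller than the monolithic one,
    # with most of the bytes moved into the PTD file.
    print(f"monolithic: {full}, split pte: {pte}, ptd: {ptd}")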

backends/xnnpack/operators/node_visitor.py

Lines changed: 5 additions & 1 deletion

@@ -621,8 +621,12 @@ def get_serialized_buffer_index(
             ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
         )
 
-        external_tag = tensor.meta.get("delegate_constant_tag", None)
+        custom_meta = tensor.meta.get("custom", None)
+        external_tag = (
+            custom_meta.get("delegate_constant_tag", None) if custom_meta else None
+        )
         if external_tag is not None:
+            external_tag = custom_meta.get("delegate_constant_tag", None)
             logging.info(
                 f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
             )
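The visitor now reads the tag from a "custom" sub-dictionary of node.meta rather than from the top level, so the lookup must guard against the sub-dictionary being absent. A minimal sketch of the two cases, with plain dicts standing in for a node's .meta:

    # Tagged tensor: the tag sits one level down, under "custom".
    meta = {"custom": {"delegate_constant_tag": "model.ptd"}}
    custom_meta = meta.get("custom", None)
    tag = custom_meta.get("delegate_constant_tag", None) if custom_meta else None
    assert tag == "model.ptd"

    # Untagged tensor: no "custom" entry at all, so the guard matters.
    custom_meta = {}.get("custom", None)
    assert (custom_meta.get("delegate_constant_tag", None) if custom_meta else None) is None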

examples/models/llama/TARGETS

Lines changed: 1 addition & 1 deletion

@@ -153,10 +153,10 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/extension/llm/export/config:llm_config",
         "//executorch/backends/vulkan/_passes:vulkan_passes",
+        "//executorch/exir/passes:external_constants_pass",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
-        "//executorch/exir/passes:init_mutable_pass",
         "//executorch/extension/llm/custom_ops:custom_ops_aot_py",
         "//executorch/extension/llm/export:export_lib",
         # one definition has to be included in the user of the libarary

examples/models/llama/export_llama_lib.py

Lines changed: 16 additions & 0 deletions

@@ -1078,6 +1078,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
+        if llm_config.export.foundation_weights_file is not None:
+            gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: (
+                llm_config.export.foundation_weights_file
+                if "lora" not in x.name
+                else None
+            )
+
+            from executorch.exir.passes.external_constants_pass import (
+                delegate_external_constants_pass_unlifted,
+            )
+
+            delegate_external_constants_pass_unlifted(
+                gm=builder_exported.pre_autograd_graph_module,
+                gen_tag_fn=gen_tag_fn,
+            )
+
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
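The routing rule above is purely name-based: any constant whose node name lacks "lora" is tagged with the foundation weights file, while LoRA adapter weights are left untagged and stay inside the PTE. A self-contained sketch of the rule (the node names below are invented for illustration):

    from typing import Optional

    foundation_file = "llama_3_2_1B_lora_separate.ptd"  # matches the CI script

    def gen_tag(node_name: str) -> Optional[str]:
        # None means "keep this constant inside the PTE".
        return foundation_file if "lora" not in node_name else None

    for name in ("layers_0_attention_wq_weight", "layers_0_attention_wq_lora_a"):
        print(name, "->", gen_tag(name))
    # layers_0_attention_wq_weight -> llama_3_2_1B_lora_separate.ptd
    # layers_0_attention_wq_lora_a -> None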

exir/passes/external_constants_pass.py

Lines changed: 23 additions & 1 deletion

@@ -113,6 +113,28 @@ def delegate_external_constants_pass(
         for node in module.graph.nodes:
             if node.op == "placeholder" and is_param_node(ep, node):
                 if gen_tag_fn is not None:
-                    node.meta["delegate_constant_tag"] = gen_tag_fn(node)
+                    node.meta.setdefault("custom", {})
+                    node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
+                    mutated = True
+    return PassResult(gm, mutated)
+
+
+# Note: this pass must be run on an unlifted graph, e.g. ep.module(),
+# and not on a lifted graph, e.g. ep.graph_module.
+# This is using 'get_attr' to tag constants, which only appears in
+# unlifted graphs.
+def delegate_external_constants_pass_unlifted(
+    gm: GraphModule,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
+) -> PassResult:
+    mutated = False
+    for module in gm.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+        for node in module.graph.nodes:
+            if node.op == "get_attr":
+                if gen_tag_fn is not None:
+                    node.meta.setdefault("custom", {})
+                    node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
                     mutated = True
     return PassResult(gm, mutated)
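The note above is the key constraint: torch.export lifts parameters into graph inputs, so a lifted graph has no get_attr nodes to tag, while ep.module() re-installs the weights as attributes fetched via get_attr. A small demonstration of the difference, assuming a recent torch.export:

    import torch

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    ep = torch.export.export(Tiny(), (torch.randn(1, 4),))

    # Lifted graph: weights arrive as extra "placeholder" inputs.
    print({n.op for n in ep.graph_module.graph.nodes})
    # Unlifted module: weights are read via "get_attr", which is what
    # delegate_external_constants_pass_unlifted looks for.
    print({n.op for n in ep.module().graph.nodes})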

exir/program/_program.py

Lines changed: 3 additions & 1 deletion

@@ -1908,7 +1908,9 @@ def write_tensor_data_to_file(self, outdir) -> None:
         """
         assert self._tensor_data is not None
         for filename, cord in self._tensor_data.items():
-            with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f:
+            if not filename.endswith(".ptd"):
+                filename += ".ptd"
+            with open(os.path.join(outdir, f"{filename}"), "wb") as f:
                 logging.info(f"Writing data file to {filename}")
                 cord.write_to_file(f)
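The suffix check makes the writer tolerant of callers that already pass a name ending in ".ptd" (as export.foundation_weights_file now does). A one-line sketch of the normalization:

    def normalize(filename: str) -> str:
        # Both "model" and "model.ptd" resolve to the same on-disk name.
        return filename if filename.endswith(".ptd") else filename + ".ptd"

    assert normalize("model") == normalize("model.ptd") == "model.ptd"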

extension/llm/export/config/llm_config.py

Lines changed: 6 additions & 0 deletions

@@ -211,6 +211,9 @@ class ExportConfig:
         so_library: Shared library to specify custom quantized operators.
         export_only: Whether to stop right after torch.export() and
             just save the exported .pt2 graph file.
+        foundation_weights_file: configure the foundation weights of a model
+            to be placed in a separate file, external to the PTE. Pass the
+            intended file name here.
     """
 
     max_seq_length: int = 128
@@ -219,6 +222,7 @@ class ExportConfig:
     output_name: Optional[str] = None
     so_library: Optional[str] = None
     export_only: bool = False
+    foundation_weights_file: Optional[str] = None
 
     def __post_init__(self):
         if self.max_context_length < self.max_seq_length:
@@ -545,6 +549,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
             llm_config.export.so_library = args.so_library
         if hasattr(args, "export_only"):
             llm_config.export.export_only = args.export_only
+        if hasattr(args, "foundation_weights_file"):
+            llm_config.export.foundation_weights_file = args.foundation_weights_file
 
         # QuantizationConfig
        if hasattr(args, "quantization_mode"):
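Besides the hydra-style CLI flags used in the test script, the new field can be set programmatically. A hedged sketch, assuming LlmConfig is importable from the module path shown above and constructs with defaults:

    from executorch.extension.llm.export.config.llm_config import LlmConfig

    cfg = LlmConfig()
    cfg.export.output_name = "model.pte"
    # Foundation (non-LoRA) weights will be written to this separate file.
    cfg.export.foundation_weights_file = "model.ptd"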

runtime/executor/merged_data_map.h

Lines changed: 3 additions & 1 deletion

@@ -37,8 +37,10 @@ class MergedDataMap final : public NamedDataMap {
     // Check for duplicate keys.
     for (uint32_t k = 0; k < first->get_num_keys().get(); k++) {
       const auto key = first->get_key(k).get();
+      const auto error = second->get_tensor_layout(key).error();
+      // TODO(lfq): add API to check if key exists.
       ET_CHECK_OR_RETURN_ERROR(
-          second->get_tensor_layout(key).error() == Error::NotFound,
+          error == Error::NotFound || error == Error::NotImplemented,
          InvalidArgument,
          "Duplicate key %s.",
          key);
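The relaxed condition accepts Error::NotImplemented because some data maps cannot answer get_tensor_layout at all; until the TODO's key-existence API lands, "not found or not implemented" is treated as "no duplicate". A Python analogue of the merge-time invariant (illustrative only, not the ExecuTorch API):

    def check_no_duplicate_keys(first: dict, second: dict) -> None:
        # Mirrors the C++ loop: every key of the first map must be
        # absent from the second, or merging would be ambiguous.
        for key in first:
            if key in second:
                raise ValueError(f"Duplicate key {key}.")

    check_no_duplicate_keys({"foundation": b"..."}, {"adapter": b"..."})  # passes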
