Save foundation weights separately #13268


Merged: 1 commit, Aug 11, 2025
65 changes: 51 additions & 14 deletions .ci/scripts/test_llama_lora.sh
@@ -48,8 +48,17 @@ DOWNLOADED_PATH=$(
   --model_id "${HF_MODEL_REPO}" \
   --files "adapter_config.json" "adapter_model.pt" "consolidated.00.pth" "params.json" "tokenizer.model"
 )
-EXPORTED_MODEL_NAME="llama_3_2_1B_lora.pte"
-# Export model.
+# Build llama runner.
+cmake_install_executorch_libraries
+cmake_build_llama_runner
+
+# Constants.
+RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
+PROMPT="What happens if you eat watermelon seeds?"
+EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
+
+# Export LoRA PTE file.
+MODEL_NAME="llama_3_2_1B_lora"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
   base.params="${DOWNLOADED_PATH}/params.json" \
@@ -61,36 +70,64 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   model.dtype_override="fp32" \
   backend.xnnpack.enabled=true \
   backend.xnnpack.extended_ops=true \
-  export.output_name="${EXPORTED_MODEL_NAME}"
-
-# Build llama runner.
-cmake_install_executorch_libraries
-cmake_build_llama_runner
+  export.output_name="${MODEL_NAME}.pte"
 
-PROMPT="What happens if you eat watermelon seeds?"
 # Run llama runner
-RUNTIME_ARGS="--model_path=${EXPORTED_MODEL_NAME} --tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
-
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_NAME}.pte --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
 RESULT=$(cat result.txt)
-EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
-
 if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT}"
+  # Do not clean up files if test passes, as they're re-used in the next test.
   echo "Success"
-  cleanup_files
 else
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT}"
   echo "Failure; results not the same"
   cleanup_files
   exit 1
 fi
 
+# Export LoRA PTE, PTD file.
+MODEL_SEPARATE="${MODEL_NAME}_separate"
+$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
+  base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
+  base.params="${DOWNLOADED_PATH}/params.json" \
+  base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \
+  base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \
+  base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
+  model.use_kv_cache=true \
+  model.use_sdpa_with_kv_cache=true \
+  model.dtype_override="fp32" \
+  backend.xnnpack.enabled=true \
+  backend.xnnpack.extended_ops=true \
+  export.output_name="${MODEL_SEPARATE}.pte" \
+  export.foundation_weights_file="${MODEL_SEPARATE}.ptd"
+
+# Run llama runner.
+NOW=$(date +"%H:%M:%S")
+echo "Starting to run llama runner at ${NOW}"
+# shellcheck source=/dev/null
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+NOW=$(date +"%H:%M:%S")
+echo "Finished at ${NOW}"
+
+RESULT2=$(cat result2.txt)
+if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT2}"
+  echo "Success"
+  cleanup_files
+else
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT2}"
+  echo "Failure; results not the same"
+  cleanup_files
+  exit 1
+fi
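
A quick way to sanity-check the split is to compare the sizes of the two artifacts the second export produces. A minimal sketch (the file names come from the script above; the size relationship is an expectation, not something the test asserts):

```python
import os

# Artifacts produced by the second export in the CI script above.
pte = "llama_3_2_1B_lora_separate.pte"
ptd = "llama_3_2_1B_lora_separate.ptd"

pte_mb = os.path.getsize(pte) / 1e6
ptd_mb = os.path.getsize(ptd) / 1e6

# With the foundation weights externalized, the .ptd should hold the bulk of
# the bytes, while the .pte keeps only the program and the LoRA weights.
print(f"pte: {pte_mb:.1f} MB, ptd: {ptd_mb:.1f} MB")
```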
6 changes: 5 additions & 1 deletion backends/xnnpack/operators/node_visitor.py
@@ -621,8 +621,12 @@ def get_serialized_buffer_index(
             ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
         )
 
-        external_tag = tensor.meta.get("delegate_constant_tag", None)
+        custom_meta = tensor.meta.get("custom", None)
+        external_tag = (
+            custom_meta.get("delegate_constant_tag", None) if custom_meta else None
+        )
         if external_tag is not None:
+            external_tag = custom_meta.get("delegate_constant_tag", None)
             logging.info(
                 f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
             )
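
This is the consumer side of the new tagging scheme: the tag now lives in the node's `custom` metadata dictionary rather than directly on `meta`. A minimal sketch of the producer/consumer handshake, using a plain dict to stand in for `node.meta`:

```python
# Plain dict standing in for a torch.fx node's meta (illustration only).
meta = {}

# Producer (the external-constants pass): nest the tag under "custom".
meta.setdefault("custom", {})
meta["custom"]["delegate_constant_tag"] = "foundation.ptd"

# Consumer (node_visitor): tolerate untagged nodes at every level.
custom_meta = meta.get("custom", None)
external_tag = (
    custom_meta.get("delegate_constant_tag", None) if custom_meta else None
)
assert external_tag == "foundation.ptd"
```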
2 changes: 1 addition & 1 deletion examples/models/llama/TARGETS
@@ -153,10 +153,10 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/extension/llm/export/config:llm_config",
         "//executorch/backends/vulkan/_passes:vulkan_passes",
+        "//executorch/exir/passes:external_constants_pass",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
-        "//executorch/exir/passes:init_mutable_pass",
         "//executorch/extension/llm/custom_ops:custom_ops_aot_py",
         "//executorch/extension/llm/export:export_lib",
         # one definition has to be included in the user of the libarary
16 changes: 16 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -1078,6 +1078,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
+        if llm_config.export.foundation_weights_file is not None:
+            gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: (
+                llm_config.export.foundation_weights_file
+                if "lora" not in x.name
+                else None
+            )
+
+            from executorch.exir.passes.external_constants_pass import (
+                delegate_external_constants_pass_unlifted,
+            )
+
+            delegate_external_constants_pass_unlifted(
+                gm=builder_exported.pre_autograd_graph_module,
+                gen_tag_fn=gen_tag_fn,
+            )
+
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
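
The routing rule is carried by `gen_tag_fn`: any constant whose node name contains "lora" stays untagged and is serialized into the PTE, while everything else is tagged for the foundation weights file. A standalone sketch of that rule (the node names are hypothetical):

```python
from types import SimpleNamespace

# Stands in for llm_config.export.foundation_weights_file.
foundation_file = "foundation.ptd"

# Same predicate as in _export_llama: LoRA tensors stay inline, the rest go external.
gen_tag_fn = lambda node: foundation_file if "lora" not in node.name else None

base_weight = SimpleNamespace(name="layers_0_attention_wq_weight")
lora_weight = SimpleNamespace(name="layers_0_attention_wq_lora_a_weight")

assert gen_tag_fn(base_weight) == "foundation.ptd"  # tagged: written to the PTD
assert gen_tag_fn(lora_weight) is None              # untagged: stays in the PTE
```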
24 changes: 23 additions & 1 deletion exir/passes/external_constants_pass.py
@@ -113,6 +113,28 @@ def delegate_external_constants_pass(
     for node in module.graph.nodes:
         if node.op == "placeholder" and is_param_node(ep, node):
             if gen_tag_fn is not None:
-                node.meta["delegate_constant_tag"] = gen_tag_fn(node)
+                node.meta.setdefault("custom", {})
+                node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
                 mutated = True
     return PassResult(gm, mutated)
+
+
+# Note: this pass must be run on an unlifted graph, e.g. ep.module(),
+# and not on a lifted graph, e.g. ep.graph_module.
+# This is using 'get_attr' to tag constants, which only appears in
+# unlifted graphs.
+def delegate_external_constants_pass_unlifted(
+    gm: GraphModule,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
+) -> PassResult:
+    mutated = False
+    for module in gm.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+        for node in module.graph.nodes:
+            if node.op == "get_attr":
+                if gen_tag_fn is not None:
+                    node.meta.setdefault("custom", {})
+                    node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
+                    mutated = True
+    return PassResult(gm, mutated)
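
Per the note above, the unlifted variant must be handed `ep.module()` (or another unlifted GraphModule, such as the pre-autograd graph used in `_export_llama`), since only unlifted graphs contain `get_attr` nodes. A toy sketch of invoking the pass, assuming `PassResult` exposes the usual `modified` flag:

```python
import torch
from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.weight

ep = torch.export.export(Toy(), (torch.randn(2, 4),))
gm = ep.module()  # unlifted: the parameter appears as a get_attr node

result = delegate_external_constants_pass_unlifted(
    gm=gm,
    gen_tag_fn=lambda node: "foundation.ptd",  # tag every constant
)
assert result.modified  # the weight node was tagged
```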
4 changes: 3 additions & 1 deletion exir/program/_program.py
@@ -1908,7 +1908,9 @@ def write_tensor_data_to_file(self, outdir) -> None:
         """
         assert self._tensor_data is not None
         for filename, cord in self._tensor_data.items():
-            with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f:
+            if not filename.endswith(".ptd"):
+                filename += ".ptd"
+            with open(os.path.join(outdir, f"{filename}"), "wb") as f:
                 logging.info(f"Writing data file to {filename}")
                 cord.write_to_file(f)
 
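The new guard makes the output naming idempotent: callers may register a tensor-data key with or without the `.ptd` suffix and get the same file name either way. A small sketch of the behavior:

```python
def normalize_ptd_name(filename: str) -> str:
    # Mirrors the guard in write_tensor_data_to_file: append ".ptd" only
    # if the caller did not already include it.
    if not filename.endswith(".ptd"):
        filename += ".ptd"
    return filename

assert normalize_ptd_name("model") == "model.ptd"
assert normalize_ptd_name("model.ptd") == "model.ptd"  # no doubled suffix
```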
6 changes: 6 additions & 0 deletions extension/llm/export/config/llm_config.py
@@ -211,6 +211,9 @@ class ExportConfig:
         so_library: Shared library to specify custom quantized operators.
         export_only: Whether to stop right after torch.export() and
             just save the exported .pt2 graph file.
+        foundation_weights_file: configure the foundation weights of a model
+            to be placed in a separate file, external to the PTE. Pass the
+            intended file name here.
     """
 
     max_seq_length: int = 128
@@ -219,6 +222,7 @@ class ExportConfig:
     output_name: Optional[str] = None
     so_library: Optional[str] = None
     export_only: bool = False
+    foundation_weights_file: Optional[str] = None
 
     def __post_init__(self):
         if self.max_context_length < self.max_seq_length:
@@ -545,6 +549,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         llm_config.export.so_library = args.so_library
     if hasattr(args, "export_only"):
         llm_config.export.export_only = args.export_only
+    if hasattr(args, "foundation_weights_file"):
+        llm_config.export.foundation_weights_file = args.foundation_weights_file
 
     # QuantizationConfig
     if hasattr(args, "quantization_mode"):
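
From Python, enabling the split is a single config field. A hedged sketch of the programmatic equivalent of the CLI flag `export.foundation_weights_file=...`, assuming `LlmConfig` constructs with default sub-configs:

```python
from executorch.extension.llm.export.config.llm_config import LlmConfig

cfg = LlmConfig()
cfg.export.output_name = "model.pte"
# Place the foundation (non-LoRA) weights in a separate .ptd file.
cfg.export.foundation_weights_file = "model.ptd"

assert cfg.export.foundation_weights_file == "model.ptd"
```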
4 changes: 3 additions & 1 deletion runtime/executor/merged_data_map.h
@@ -37,8 +37,10 @@ class MergedDataMap final : public NamedDataMap {
   // Check for duplicate keys.
   for (uint32_t k = 0; k < first->get_num_keys().get(); k++) {
     const auto key = first->get_key(k).get();
+    const auto error = second->get_tensor_layout(key).error();
+    // TODO(lfq): add API to check if key exists.
     ET_CHECK_OR_RETURN_ERROR(
-        second->get_tensor_layout(key).error() == Error::NotFound,
+        error == Error::NotFound || error == Error::NotImplemented,
         InvalidArgument,
         "Duplicate key %s.",
         key);
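
The relaxed check treats both `NotFound` and `NotImplemented` as "the key is absent from the second map", since a data map that does not implement `get_tensor_layout` cannot be probed for duplicates at all (hence the TODO for a dedicated key-existence API). A Python analog of the guard's logic, with hypothetical names:

```python
from enum import Enum

class Error(Enum):
    Ok = 0
    NotFound = 1
    NotImplemented = 2
    InvalidArgument = 3

def check_no_duplicate_keys(first_keys, second_layout_error) -> Error:
    # A key from the first map is acceptable only if probing the second map
    # fails with NotFound (absent) or NotImplemented (map cannot be probed).
    for key in first_keys:
        err = second_layout_error(key)
        if err not in (Error.NotFound, Error.NotImplemented):
            return Error.InvalidArgument  # duplicate key
    return Error.Ok

# Usage: a second map that holds "bias" (lookup succeeds) but not "weight".
probe = lambda key: Error.Ok if key == "bias" else Error.NotFound
assert check_no_duplicate_keys(["weight"], probe) == Error.Ok
assert check_no_duplicate_keys(["bias"], probe) == Error.InvalidArgument
```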