Commit a7fe5f5

committed
Export LoRA weights to a separate file
Pull Request resolved: #14756
ghstack-source-id: 314404070
@exported-using-ghexport
Differential Revision: [D83777195](https://our.internmc.facebook.com/intern/diff/D83777195/)
1 parent 9a34b1d commit a7fe5f5
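
Summary of the change, as reflected in the diffs below: previously, setting export.foundation_weights_file moved the foundation weights into an external .ptd file while nodes whose names contain "lora" were tagged None and kept inside the PTE. The new export.lora_weights_file option tags those LoRA nodes with a second external file instead. On the runtime side, llama_main's single --data_path flag becomes --data_paths, which accepts a comma-separated list so the runner can load both weight files alongside the program-only PTE.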

File tree

4 files changed: +47 -16 lines changed

.ci/scripts/test_llama_lora.sh
examples/models/llama/export_llama_lib.py
examples/models/llama/main.cpp
extension/llm/export/config/llm_config.py

.ci/scripts/test_llama_lora.sh

Lines changed: 7 additions & 4 deletions
@@ -95,7 +95,9 @@ else
 fi
 
 # Export LoRA PTE, PTD file.
-MODEL_SEPARATE="${MODEL_NAME}_separate"
+MODEL_PROGRAM_ONLY="${MODEL_NAME}_program"
+MODEL_LORA_WEIGHTS="lora_weights"
+MODEL_FOUNDATION_WEIGHTS="foundation_weights"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
   base.params="${DOWNLOADED_PATH}/params.json" \
@@ -107,14 +109,15 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   model.dtype_override="fp32" \
   backend.xnnpack.enabled=true \
   backend.xnnpack.extended_ops=true \
-  export.output_name="${MODEL_SEPARATE}.pte" \
-  export.foundation_weights_file="${MODEL_SEPARATE}.ptd"
+  export.output_name="${MODEL_PROGRAM_ONLY}.pte" \
+  export.foundation_weights_file="${MODEL_FOUNDATION_WEIGHTS}.ptd" \
+  export.lora_weights_file="${MODEL_LORA_WEIGHTS}.ptd"
 
 # Run llama runner.
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_PROGRAM_ONLY}.pte --data_paths="${MODEL_FOUNDATION_WEIGHTS}.ptd,${MODEL_LORA_WEIGHTS}.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
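
With the flags above, the export step now emits three artifacts instead of two: a program-only .pte, a foundation-weights .ptd, and a LoRA-weights .ptd; the runner then consumes the two .ptd files via --data_paths. A minimal Python sanity-check sketch of that layout (not part of this change; the MODEL_NAME value is a hypothetical stand-in for the script's variable):

import os
import sys

model_name = "llama"  # hypothetical stand-in for ${MODEL_NAME} in the script above
artifacts = [
    f"{model_name}_program.pte",  # program only, weights externalized
    "foundation_weights.ptd",     # base model weights
    "lora_weights.ptd",           # LoRA adapter weights
]

# Fail fast if the export step did not produce all three files.
for path in artifacts:
    if not os.path.isfile(path) or os.path.getsize(path) == 0:
        sys.exit(f"missing or empty export artifact: {path}")
print("all export artifacts present")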

examples/models/llama/export_llama_lib.py

Lines changed: 5 additions & 3 deletions
@@ -1088,13 +1088,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
-        if llm_config.export.foundation_weights_file is not None:
+        if (
+            llm_config.export.foundation_weights_file is not None
+            or llm_config.export.lora_weights_file is not None
+        ):
             gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
-                else None
+                else llm_config.export.lora_weights_file
             )
-
             from executorch.exir.passes.external_constants_pass import (
                 delegate_external_constants_pass_unlifted,
                 external_constants_pass,
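
The routing rule is a single predicate on the node name: any constant whose name contains "lora" is tagged with the LoRA weights file, everything else with the foundation weights file; if only one of the two options is set, the other group keeps a None tag and stays inside the PTE. A standalone sketch of that routing (a plain class stands in for torch.fx.Node, and the node names are made up for illustration):

from typing import Optional

class FakeNode:
    """Stand-in for torch.fx.Node; the tag function only reads .name."""
    def __init__(self, name: str) -> None:
        self.name = name

foundation_file: Optional[str] = "foundation_weights.ptd"
lora_file: Optional[str] = "lora_weights.ptd"

# Same shape as gen_tag_fn in the diff above.
gen_tag_fn = lambda x: foundation_file if "lora" not in x.name else lora_file

for node in [FakeNode("attention_wq_weight"), FakeNode("wq_lora_a_weight")]:
    print(node.name, "->", gen_tag_fn(node))
# attention_wq_weight -> foundation_weights.ptd
# wq_lora_a_weight -> lora_weights.ptd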

examples/models/llama/main.cpp

Lines changed: 28 additions & 6 deletions
@@ -8,6 +8,8 @@
  */
 
 #include <gflags/gflags.h>
+#include <sstream>
+#include <vector>
 
 #include <executorch/examples/models/llama/runner/runner.h>
 
@@ -21,7 +23,10 @@ DEFINE_string(
     "llama2.pte",
     "Model serialized in flatbuffer format.");
 
-DEFINE_string(data_path, "", "Data file for the model.");
+DEFINE_string(
+    data_paths,
+    "",
+    "Data files for the model. If multiple files are provided, they should be comma separated.");
 
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
 
@@ -54,6 +59,26 @@ DEFINE_int32(
 
 DEFINE_bool(warmup, false, "Whether to run a warmup run.");
 
+// Helper function to parse comma-separated string lists
+std::vector<std::string> parseStringList(const std::string& input) {
+  std::vector<std::string> result;
+  if (input.empty()) {
+    return result;
+  }
+
+  std::stringstream ss(input);
+  std::string item;
+  while (std::getline(ss, item, ',')) {
+    // Trim whitespace
+    item.erase(0, item.find_first_not_of(" \t"));
+    item.erase(item.find_last_not_of(" \t") + 1);
+    if (!item.empty()) {
+      result.push_back(item);
+    }
+  }
+  return result;
+}
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
@@ -62,10 +87,7 @@ int32_t main(int32_t argc, char** argv) {
   // and users can create their own DataLoaders to load from arbitrary sources.
   const char* model_path = FLAGS_model_path.c_str();
 
-  std::optional<std::string> data_path = std::nullopt;
-  if (!FLAGS_data_path.empty()) {
-    data_path = FLAGS_data_path.c_str();
-  }
+  std::vector<std::string> data_paths = parseStringList(FLAGS_data_paths);
 
   const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
 
@@ -92,7 +114,7 @@ int32_t main(int32_t argc, char** argv) {
 #endif
   // create llama runner
   std::unique_ptr<::executorch::extension::llm::TextLLMRunner> runner =
-      example::create_llama_runner(model_path, tokenizer_path, data_path);
+      example::create_llama_runner(model_path, tokenizer_path, data_paths);
 
   if (runner == nullptr) {
     ET_LOG(Error, "Failed to create llama runner");
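
parseStringList splits the flag value on commas, trims leading and trailing spaces and tabs from each entry, and drops empty entries, so --data_paths="a.ptd, b.ptd" yields two paths. The same behavior sketched in Python for reference (not part of the change):

def parse_string_list(s: str) -> list[str]:
    # Mirror of parseStringList in main.cpp: split on commas, trim, drop empties.
    return [item.strip(" \t") for item in s.split(",") if item.strip(" \t")]

assert parse_string_list("") == []
assert parse_string_list("foundation_weights.ptd, lora_weights.ptd") == [
    "foundation_weights.ptd",
    "lora_weights.ptd",
]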

extension/llm/export/config/llm_config.py

Lines changed: 7 additions & 3 deletions
@@ -215,9 +215,10 @@ class ExportConfig:
         so_library: Shared library to specify custom quantized operators.
         export_only: Whether to stop right after torch.export() and
             just save the exported .pt2 graph file.
-        foundation_weights_file: configure the foundation weights of a model
-            to be placed in a separate file, external to the PTE. Pass the
-            intended file name here.
+        foundation_weights_file: place the foundation weights of the model into
+            a separate file, external to the PTE. Pass the file name here.
+        lora_weights_file: place the lora weights of the model into a
+            separate file, external to the PTE. Pass the file name here.
     """
 
     max_seq_length: int = 128
@@ -227,6 +228,7 @@ class ExportConfig:
     so_library: Optional[str] = None
     export_only: bool = False
     foundation_weights_file: Optional[str] = None
+    lora_weights_file: Optional[str] = None
 
     def __post_init__(self):
         if self.max_context_length < self.max_seq_length:
@@ -572,6 +574,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         llm_config.export.export_only = args.export_only
     if hasattr(args, "foundation_weights_file"):
         llm_config.export.foundation_weights_file = args.foundation_weights_file
+    if hasattr(args, "lora_weights_file"):
+        llm_config.export.lora_weights_file = args.lora_weights_file
 
     # QuantizationConfig
     if hasattr(args, "quantization_mode"):
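
Both fields can also be set programmatically rather than through CLI args. A minimal sketch, assuming the config module is importable as executorch.extension.llm.export.config.llm_config (import path inferred from the file location above, not confirmed by this diff):

# Import path inferred from extension/llm/export/config/llm_config.py (assumption).
from executorch.extension.llm.export.config.llm_config import LlmConfig

llm_config = LlmConfig()
# Externalize base and LoRA weights into separate .ptd files.
llm_config.export.foundation_weights_file = "foundation_weights.ptd"
llm_config.export.lora_weights_file = "lora_weights.ptd"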
