
Commit b2e4d69: sync internal and oss
1 parent: 45336ce

3 files changed: +82 -23 lines

.ci/scripts/test_llama_lora.sh

Lines changed: 45 additions & 3 deletions
@@ -94,7 +94,7 @@ else
   exit 1
 fi
 
-# Export LoRA PTE, PTD file.
+# Export LoRA PTE, foundation PTD file.
 MODEL_SEPARATE="${MODEL_NAME}_separate"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
@@ -114,20 +114,62 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_paths=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
 RESULT2=$(cat result2.txt)
 if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT2}"
+  # Do not clean up files if test passes, as they're re-used in the next test.
   echo "Success"
-  cleanup_files
 else
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT2}"
   echo "Failure; results not the same"
   cleanup_files
   exit 1
 fi
+
+# Export LoRA PTE, LoRA PTD, foundation PTD file.
+MODEL_PROGRAM_ONLY="${MODEL_NAME}_program"
+MODEL_LORA_WEIGHTS="lora_weights"
+MODEL_FOUNDATION_WEIGHTS="foundation_weights"
+$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
+  base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
+  base.params="${DOWNLOADED_PATH}/params.json" \
+  base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \
+  base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \
+  base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
+  model.use_kv_cache=true \
+  model.use_sdpa_with_kv_cache=true \
+  model.dtype_override="fp32" \
+  backend.xnnpack.enabled=true \
+  backend.xnnpack.extended_ops=true \
+  export.output_name="${MODEL_PROGRAM_ONLY}.pte" \
+  export.foundation_weights_file="${MODEL_FOUNDATION_WEIGHTS}.ptd" \
+  export.lora_weights_file="${MODEL_LORA_WEIGHTS}.ptd"
+
+# Run llama runner.
+NOW=$(date +"%H:%M:%S")
+echo "Starting to run llama runner at ${NOW}"
+# shellcheck source=/dev/null
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_PROGRAM_ONLY}.pte --data_paths="${MODEL_FOUNDATION_WEIGHTS}.ptd,${MODEL_LORA_WEIGHTS}.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result3.txt
+NOW=$(date +"%H:%M:%S")
+echo "Finished at ${NOW}"
+
+RESULT3=$(cat result3.txt)
+if [[ "${RESULT3}" == "${EXPECTED_PREFIX}"* ]]; then
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT3}"
+  echo "Success"
+else
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT3}"
+  echo "Failure; results not the same"
+  cleanup_files
+  exit 1
+fi
+
+cleanup_files
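Note: the script now exercises two separate-weights layouts: a PTE with the LoRA weights still inlined plus a foundation-weights PTD, and a program-only PTE with foundation and LoRA weights each in their own PTD. A minimal Python sketch of the artifacts and runner flags involved (names taken from the script above; `MODEL_NAME` is whichever model the CI job targets):

```python
# Sketch only: the two export layouts the test script covers, and how the
# runner is pointed at them. MODEL_NAME is a placeholder for the CI value.
MODEL_NAME = "llama"  # placeholder

layouts = {
    # PTE contains program + LoRA weights; foundation weights are external.
    "lora_inlined": {
        "pte": f"{MODEL_NAME}_separate.pte",
        "data_paths": [f"{MODEL_NAME}_separate.ptd"],
    },
    # PTE contains the program only; both weight sets are external.
    "program_only": {
        "pte": f"{MODEL_NAME}_program.pte",
        "data_paths": ["foundation_weights.ptd", "lora_weights.ptd"],
    },
}

for name, layout in layouts.items():
    flags = f"--model_path={layout['pte']} --data_paths={','.join(layout['data_paths'])}"
    print(f"{name}: {flags}")
```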

examples/models/llama/export_llama_lib.py

Lines changed: 9 additions & 14 deletions
@@ -1136,20 +1136,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
-        if llm_config.export.foundation_weights_file is not None:
-            if llm_config.export.lora_weights_file is not None:
-                gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
-                    llm_config.export.foundation_weights_file
-                    if "lora" not in x.name
-                    else None
-                )
-            else:
-                gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
-                    llm_config.export.foundation_weights_file
-                    if "lora" not in x.name
-                    else llm_config.export.lora_weights_file
-                )
-
+        if (
+            llm_config.export.foundation_weights_file is not None
+            or llm_config.export.lora_weights_file is not None
+        ):
+            gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
+                llm_config.export.foundation_weights_file
+                if "lora" not in x.name
+                else llm_config.export.lora_weights_file
+            )
         from executorch.exir.passes.external_constants_pass import (
             delegate_external_constants_pass_unlifted,
             external_constants_pass,
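The consolidation collapses the two nearly identical lambdas into one: any tensor whose node name contains "lora" is tagged with `export.lora_weights_file` (which may be None, leaving it untagged), and everything else is tagged with `export.foundation_weights_file`. A minimal sketch of that behavior, using plain strings and hypothetical node names in place of `torch.fx.Node`:

```python
from typing import Optional

# Stand-ins for llm_config.export.* values; either may be None.
foundation_weights_file: Optional[str] = "foundation_weights.ptd"
lora_weights_file: Optional[str] = "lora_weights.ptd"

def gen_tag_fn(node_name: str) -> Optional[str]:
    # LoRA tensors go to the LoRA file (or stay untagged when it is None);
    # every other tensor goes to the foundation file.
    return foundation_weights_file if "lora" not in node_name else lora_weights_file

# Hypothetical node names, for illustration only.
assert gen_tag_fn("layers.0.attention.wq.weight") == "foundation_weights.ptd"
assert gen_tag_fn("layers.0.attention.wq.lora_a.weight") == "lora_weights.ptd"
```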

examples/models/llama/main.cpp

Lines changed: 28 additions & 6 deletions
@@ -8,6 +8,8 @@
  */
 
 #include <gflags/gflags.h>
+#include <sstream>
+#include <vector>
 
 #include <executorch/examples/models/llama/runner/runner.h>
 
@@ -21,7 +23,10 @@ DEFINE_string(
     "llama2.pte",
     "Model serialized in flatbuffer format.");
 
-DEFINE_string(data_path, "", "Data file for the model.");
+DEFINE_string(
+    data_paths,
+    "",
+    "Data files for the model. If multiple files are provided, they should be comma separated.");
 
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
 
@@ -54,6 +59,26 @@ DEFINE_int32(
 
 DEFINE_bool(warmup, false, "Whether to run a warmup run.");
 
+// Helper function to parse comma-separated string lists
+std::vector<std::string> parseStringList(const std::string& input) {
+  std::vector<std::string> result;
+  if (input.empty()) {
+    return result;
+  }
+
+  std::stringstream ss(input);
+  std::string item;
+  while (std::getline(ss, item, ',')) {
+    // Trim whitespace
+    item.erase(0, item.find_first_not_of(" \t"));
+    item.erase(item.find_last_not_of(" \t") + 1);
+    if (!item.empty()) {
+      result.push_back(item);
+    }
+  }
+  return result;
+}
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
@@ -62,10 +87,7 @@ int32_t main(int32_t argc, char** argv) {
   // and users can create their own DataLoaders to load from arbitrary sources.
   const char* model_path = FLAGS_model_path.c_str();
 
-  std::optional<std::string> data_path = std::nullopt;
-  if (!FLAGS_data_path.empty()) {
-    data_path = FLAGS_data_path.c_str();
-  }
+  std::vector<std::string> data_paths = parseStringList(FLAGS_data_paths);
 
   const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
 
@@ -92,7 +114,7 @@ int32_t main(int32_t argc, char** argv) {
 #endif
   // create llama runner
   std::unique_ptr<::executorch::extension::llm::TextLLMRunner> runner =
-      example::create_llama_runner(model_path, tokenizer_path, data_path);
+      example::create_llama_runner(model_path, tokenizer_path, data_paths);
 
   if (runner == nullptr) {
     ET_LOG(Error, "Failed to create llama runner");
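For intuition: `parseStringList` splits the flag value on commas, trims spaces and tabs, and drops empty entries, so an empty `--data_paths` yields an empty vector and the runner is created with no external data files. The same parsing expressed as a quick Python sketch (not part of the codebase):

```python
def parse_string_list(s: str) -> list[str]:
    # Mirrors the C++ helper: split on commas, strip spaces/tabs, drop empties.
    items = (item.strip(" \t") for item in s.split(","))
    return [item for item in items if item]

assert parse_string_list("") == []
assert parse_string_list("foundation_weights.ptd, lora_weights.ptd") == [
    "foundation_weights.ptd",
    "lora_weights.ptd",
]
```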
