diff --git a/.ci/scripts/test_llama_lora.sh b/.ci/scripts/test_llama_lora.sh
index 6337bbf76a2..70611206253 100644
--- a/.ci/scripts/test_llama_lora.sh
+++ b/.ci/scripts/test_llama_lora.sh
@@ -95,7 +95,9 @@ else
 fi
 
 # Export LoRA PTE, PTD file.
-MODEL_SEPARATE="${MODEL_NAME}_separate"
+MODEL_PROGRAM_ONLY="${MODEL_NAME}_program"
+MODEL_LORA_WEIGHTS="lora_weights"
+MODEL_FOUNDATION_WEIGHTS="foundation_weights"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
   base.params="${DOWNLOADED_PATH}/params.json" \
@@ -107,14 +109,15 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   model.dtype_override="fp32" \
   backend.xnnpack.enabled=true \
   backend.xnnpack.extended_ops=true \
-  export.output_name="${MODEL_SEPARATE}.pte" \
-  export.foundation_weights_file="${MODEL_SEPARATE}.ptd"
+  export.output_name="${MODEL_PROGRAM_ONLY}.pte" \
+  export.foundation_weights_file="${MODEL_FOUNDATION_WEIGHTS}.ptd" \
+  export.lora_weights_file="${MODEL_LORA_WEIGHTS}.ptd"
 
 # Run llama runner.
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_PROGRAM_ONLY}.pte --data_paths="${MODEL_FOUNDATION_WEIGHTS}.ptd,${MODEL_LORA_WEIGHTS}.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index aa3b157c8da..7c2705f0a15 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1088,13 +1088,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
-        if llm_config.export.foundation_weights_file is not None:
+        if (
+            llm_config.export.foundation_weights_file is not None
+            or llm_config.export.lora_weights_file is not None
+        ):
             gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
-                else None
+                else llm_config.export.lora_weights_file
             )
-
             from executorch.exir.passes.external_constants_pass import (
                 delegate_external_constants_pass_unlifted,
                 external_constants_pass,
diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp
index 078d938ffde..0244a7f5661 100644
--- a/examples/models/llama/main.cpp
+++ b/examples/models/llama/main.cpp
@@ -8,6 +8,8 @@
  */
 
 #include <gflags/gflags.h>
+#include <sstream>
+#include <string>
 
 #include <executorch/examples/models/llama/runner/runner.h>
 
@@ -21,7 +23,10 @@ DEFINE_string(
     "llama2.pte",
     "Model serialized in flatbuffer format.");
 
-DEFINE_string(data_path, "", "Data file for the model.");
+DEFINE_string(
+    data_paths,
+    "",
+    "Data files for the model. If multiple files are provided, they should be comma separated.");
 
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
 
@@ -54,6 +59,26 @@ DEFINE_int32(
 
 DEFINE_bool(warmup, false, "Whether to run a warmup run.");
 
+// Helper function to parse comma-separated string lists
+std::vector<std::string> parseStringList(const std::string& input) {
+  std::vector<std::string> result;
+  if (input.empty()) {
+    return result;
+  }
+
+  std::stringstream ss(input);
+  std::string item;
+  while (std::getline(ss, item, ',')) {
+    // Trim whitespace
+    item.erase(0, item.find_first_not_of(" \t"));
+    item.erase(item.find_last_not_of(" \t") + 1);
+    if (!item.empty()) {
+      result.push_back(item);
+    }
+  }
+  return result;
+}
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
@@ -62,10 +87,7 @@ int32_t main(int32_t argc, char** argv) {
   // and users can create their own DataLoaders to load from arbitrary sources.
   const char* model_path = FLAGS_model_path.c_str();
 
-  std::optional<std::string> data_path = std::nullopt;
-  if (!FLAGS_data_path.empty()) {
-    data_path = FLAGS_data_path.c_str();
-  }
+  std::vector<std::string> data_paths = parseStringList(FLAGS_data_paths);
 
   const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
 
@@ -92,7 +114,7 @@ int32_t main(int32_t argc, char** argv) {
 #endif
   // create llama runner
   std::unique_ptr<::executorch::extension::llm::TextLLMRunner> runner =
-      example::create_llama_runner(model_path, tokenizer_path, data_path);
+      example::create_llama_runner(model_path, tokenizer_path, data_paths);
 
   if (runner == nullptr) {
     ET_LOG(Error, "Failed to create llama runner");
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index b13001c005b..f15aad9e000 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -215,9 +215,10 @@ class ExportConfig:
         so_library: Shared library to specify custom quantized operators.
         export_only: Whether to stop right after torch.export() and
             just save the exported .pt2 graph file.
-        foundation_weights_file: configure the foundation weights of a model
-            to be placed in a separate file, external to the PTE. Pass the
-            intended file name here.
+        foundation_weights_file: place the foundation weights of the model into
+            a separate file, external to the PTE. Pass the file name here.
+        lora_weights_file: place the lora weights of the model into a
+            separate file, external to the PTE. Pass the file name here.
     """
 
     max_seq_length: int = 128
@@ -227,6 +228,7 @@ class ExportConfig:
     so_library: Optional[str] = None
     export_only: bool = False
     foundation_weights_file: Optional[str] = None
+    lora_weights_file: Optional[str] = None
 
     def __post_init__(self):
         if self.max_context_length < self.max_seq_length:
@@ -572,6 +574,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
             llm_config.export.export_only = args.export_only
         if hasattr(args, "foundation_weights_file"):
             llm_config.export.foundation_weights_file = args.foundation_weights_file
+        if hasattr(args, "lora_weights_file"):
+            llm_config.export.lora_weights_file = args.lora_weights_file
 
         # QuantizationConfig
         if hasattr(args, "quantization_mode"):
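
The substance of this patch is the tagging rule in export_llama_lib.py: when either foundation_weights_file or lora_weights_file is set, every constant whose node name contains "lora" is routed to the LoRA .ptd file and every other constant to the foundation .ptd file, and llama_main then loads both files through the new comma-separated --data_paths flag, as the updated CI script above exercises. The snippet below is a minimal standalone sketch of that routing rule only; make_gen_tag_fn and the sample node names are illustrative (not part of the patch), and the real pass receives torch.fx.Node objects and reads x.name rather than a plain string.

from typing import Optional

def make_gen_tag_fn(
    foundation_weights_file: Optional[str],
    lora_weights_file: Optional[str],
):
    # Mirrors the lambda in _export_llama: the target data file is chosen
    # purely from the constant's node name.
    def gen_tag_fn(node_name: str) -> Optional[str]:
        return (
            foundation_weights_file
            if "lora" not in node_name
            else lora_weights_file
        )

    return gen_tag_fn

tag = make_gen_tag_fn("foundation_weights.ptd", "lora_weights.ptd")
print(tag("layers_0_attention_wq_weight"))         # -> foundation_weights.ptd
print(tag("layers_0_attention_wq_lora_a_weight"))  # -> lora_weights.ptd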