diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp
index 2ba2fdf9941..19ed9f88339 100644
--- a/examples/models/llama/runner/runner.cpp
+++ b/examples/models/llama/runner/runner.cpp
@@ -37,6 +37,21 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
     const std::string& tokenizer_path,
     std::optional<std::string> data_path,
     float temperature) {
+  if (data_path.has_value()) {
+    std::vector<std::string> data_files;
+    data_files.push_back(data_path.value());
+    return create_llama_runner(
+        model_path, tokenizer_path, std::move(data_files), temperature);
+  }
+  return create_llama_runner(
+      model_path, tokenizer_path, std::vector<std::string>(), temperature);
+}
+
+std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
+    const std::string& model_path,
+    const std::string& tokenizer_path,
+    std::vector<std::string> data_files,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
@@ -55,7 +70,7 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
     return nullptr;
   }
   return llm::create_text_llm_runner(
-      model_path, std::move(tokenizer), data_path);
+      model_path, std::move(tokenizer), data_files);
 }
 
 } // namespace example
diff --git a/examples/models/llama/runner/runner.h b/examples/models/llama/runner/runner.h
index f07cd4e8ee8..728ae57efa8 100644
--- a/examples/models/llama/runner/runner.h
+++ b/examples/models/llama/runner/runner.h
@@ -11,12 +11,9 @@
 
 #pragma once
 
-#include 
-#include 
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
 
@@ -30,7 +27,13 @@ namespace llm = ::executorch::extension::llm;
 std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<std::string> data_path = std::nullopt,
+    std::optional<std::string> data_path,
+    float temperature = -1.0f);
+
+std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
+    const std::string& model_path,
+    const std::string& tokenizer_path,
+    std::vector<std::string> data_files = {},
     float temperature = -1.0f);
 
 std::unique_ptr<tokenizers::Tokenizer> load_llama_tokenizer(
diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp
index f12de5f1d87..d1e4ff2ce45 100644
--- a/extension/llm/runner/llm_runner_helper.cpp
+++ b/extension/llm/runner/llm_runner_helper.cpp
@@ -183,6 +183,24 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(
     std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
     std::optional<std::string> data_path,
     float temperature) {
+  if (data_path.has_value()) {
+    std::vector<std::string> data_files;
+    data_files.push_back(data_path.value());
+    return create_text_llm_runner(
+        model_path, std::move(tokenizer), std::move(data_files), temperature);
+  }
+  return create_text_llm_runner(
+      model_path,
+      std::move(tokenizer),
+      std::vector<std::string>(),
+      temperature);
+}
+
+std::unique_ptr<TextLLMRunner> create_text_llm_runner(
+    const std::string& model_path,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::vector<std::string> data_files,
+    float temperature) {
   // Sanity check tokenizer
   if (!tokenizer || !tokenizer->is_loaded()) {
     ET_LOG(Error, "Tokenizer is null or not loaded");
@@ -191,9 +209,9 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(
 
   // Create the Module
   std::unique_ptr<Module> module;
-  if (data_path.has_value()) {
+  if (data_files.size() > 0) {
     module = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
+        model_path, data_files, Module::LoadMode::File);
   } else {
     module = std::make_unique<Module>(model_path, Module::LoadMode::File);
   }
diff --git a/extension/llm/runner/llm_runner_helper.h b/extension/llm/runner/llm_runner_helper.h
index 191ea3ab090..5c109581e19 100644
--- a/extension/llm/runner/llm_runner_helper.h
+++ b/extension/llm/runner/llm_runner_helper.h
@@ -101,7 +101,28 @@ ET_EXPERIMENTAL std::unordered_set<uint64_t> get_eos_ids(
 ET_EXPERIMENTAL std::unique_ptr<TextLLMRunner> create_text_llm_runner(
     const std::string& model_path,
     std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
-    std::optional<std::string> data_path = std::nullopt,
+    std::optional<std::string> data_path,
+    float temperature = -1.0f);
+
+/**
+ * @brief Creates a TextLLMRunner instance with dependency injection
+ *
+ * This factory function creates and initializes a TextLLMRunner with all
+ * necessary components for text generation using the specified model and
+ * tokenizer.
+ *
+ * @param model_path Path to the model file
+ * @param tokenizer Initialized tokenizer instance
+ * @param data_files Vector of paths to additional data required by the model
+ * @param temperature Optional temperature parameter for controlling randomness
+ * (deprecated)
+ * @return std::unique_ptr<TextLLMRunner> Initialized TextLLMRunner instance, or
+ * nullptr on failure
+ */
+ET_EXPERIMENTAL std::unique_ptr<TextLLMRunner> create_text_llm_runner(
+    const std::string& model_path,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::vector<std::string> data_files = {},
     float temperature = -1.0f);
 
 /**
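For reviewers, a minimal caller-side sketch of the two `create_llama_runner` overloads after this change. The `.pte`/`.ptd`/tokenizer file names and the include line for the example runner header are placeholders chosen for illustration; only the overload signatures themselves come from the patch.

```cpp
// Usage sketch only: file paths and the include location are assumptions,
// not part of the patch.
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <executorch/examples/models/llama/runner/runner.h> // assumed path

int main() {
  // New overload: the program can be backed by any number of .ptd data files.
  std::vector<std::string> data_files = {
      "llama_weights_part0.ptd", // placeholder file names
      "llama_weights_part1.ptd",
  };
  auto runner = example::create_llama_runner(
      "llama.pte", "tokenizer.model", std::move(data_files));

  // Existing call sites keep working: a single optional data_path is wrapped
  // into a one-element vector and forwarded to the data_files overload.
  auto legacy_runner = example::create_llama_runner(
      "llama.pte",
      "tokenizer.model",
      std::optional<std::string>("llama_weights.ptd"));

  return (runner && legacy_runner) ? 0 : 1;
}
```

A call that passes only a model and tokenizer path now resolves to the `data_files` overload with an empty vector, since `data_path` no longer has a default argument; the `data_path` variants survive purely as forwarding shims for existing call sites, and only the `data_files` path constructs the `Module` in `create_text_llm_runner`.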