Skip to content

Commit c44c541

Browse files
pytorchbot and lucylq authored
Runner support for multiple ptd files (#14758)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #14159 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/111/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/111/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/111/orig Differential Revision: [D82072385](https://our.internmc.facebook.com/intern/diff/D82072385/) @diff-train-skip-merge Co-authored-by: lucylq <[email protected]>
1 parent 3557edf commit c44c541

File tree

4 files changed

+65
-8
lines changed

4 files changed

+65
-8
lines changed

examples/models/llama/runner/runner.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,21 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
3737
const std::string& tokenizer_path,
3838
std::optional<const std::string> data_path,
3939
float temperature) {
40+
if (data_path.has_value()) {
41+
std::vector<std::string> data_files;
42+
data_files.push_back(data_path.value());
43+
return create_llama_runner(
44+
model_path, tokenizer_path, std::move(data_files), temperature);
45+
}
46+
return create_llama_runner(
47+
model_path, tokenizer_path, std::vector<std::string>(), temperature);
48+
}
49+
50+
std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
51+
const std::string& model_path,
52+
const std::string& tokenizer_path,
53+
std::vector<std::string> data_files,
54+
float temperature) {
4055
ET_LOG(
4156
Info,
4257
"Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
@@ -55,7 +70,7 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
5570
return nullptr;
5671
}
5772
return llm::create_text_llm_runner(
58-
model_path, std::move(tokenizer), data_path);
73+
model_path, std::move(tokenizer), data_files);
5974
}
6075

6176
} // namespace example

examples/models/llama/runner/runner.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111

1212
#pragma once
1313

14-
#include <cstdint>
15-
#include <functional>
1614
#include <memory>
1715
#include <optional>
1816
#include <string>
19-
#include <unordered_map>
2017

2118
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
2219
#include <executorch/extension/llm/runner/irunner.h>
@@ -30,7 +27,13 @@ namespace llm = ::executorch::extension::llm;
3027
std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
3128
const std::string& model_path,
3229
const std::string& tokenizer_path,
33-
std::optional<const std::string> data_path = std::nullopt,
30+
std::optional<const std::string> data_path,
31+
float temperature = -1.0f);
32+
33+
std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
34+
const std::string& model_path,
35+
const std::string& tokenizer_path,
36+
std::vector<std::string> data_files = {},
3437
float temperature = -1.0f);
3538

3639
std::unique_ptr<tokenizers::Tokenizer> load_llama_tokenizer(

extension/llm/runner/llm_runner_helper.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,24 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(
183183
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
184184
std::optional<const std::string> data_path,
185185
float temperature) {
186+
if (data_path.has_value()) {
187+
std::vector<std::string> data_files;
188+
data_files.push_back(data_path.value());
189+
return create_text_llm_runner(
190+
model_path, std::move(tokenizer), std::move(data_files), temperature);
191+
}
192+
return create_text_llm_runner(
193+
model_path,
194+
std::move(tokenizer),
195+
std::vector<std::string>(),
196+
temperature);
197+
}
198+
199+
std::unique_ptr<TextLLMRunner> create_text_llm_runner(
200+
const std::string& model_path,
201+
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
202+
std::vector<std::string> data_files,
203+
float temperature) {
186204
// Sanity check tokenizer
187205
if (!tokenizer || !tokenizer->is_loaded()) {
188206
ET_LOG(Error, "Tokenizer is null or not loaded");
@@ -191,9 +209,9 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(
191209

192210
// Create the Module
193211
std::unique_ptr<Module> module;
194-
if (data_path.has_value()) {
212+
if (data_files.size() > 0) {
195213
module = std::make_unique<Module>(
196-
model_path, data_path.value(), Module::LoadMode::File);
214+
model_path, data_files, Module::LoadMode::File);
197215
} else {
198216
module = std::make_unique<Module>(model_path, Module::LoadMode::File);
199217
}

extension/llm/runner/llm_runner_helper.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,28 @@ ET_EXPERIMENTAL std::unordered_set<uint64_t> get_eos_ids(
101101
ET_EXPERIMENTAL std::unique_ptr<TextLLMRunner> create_text_llm_runner(
102102
const std::string& model_path,
103103
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
104-
std::optional<const std::string> data_path = std::nullopt,
104+
std::optional<const std::string> data_path,
105+
float temperature = -1.0f);
106+
107+
/**
108+
* @brief Creates a TextLLMRunner instance with dependency injection
109+
*
110+
* This factory function creates and initializes a TextLLMRunner with all
111+
* necessary components for text generation using the specified model and
112+
* tokenizer.
113+
*
114+
* @param model_path Path to the model file
115+
* @param tokenizer Initialized tokenizer instance
116+
* @param data_files Vector of paths to additional data required by the model
117+
* @param temperature Optional temperature parameter for controlling randomness
118+
* (deprecated)
119+
* @return std::unique_ptr<TextLLMRunner> Initialized TextLLMRunner instance, or
120+
* nullptr on failure
121+
*/
122+
ET_EXPERIMENTAL std::unique_ptr<TextLLMRunner> create_text_llm_runner(
123+
const std::string& model_path,
124+
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
125+
std::vector<std::string> data_files = {},
105126
float temperature = -1.0f);
106127

107128
/**

0 commit comments

Comments (0)