Skip to content

Commit 325bbc0

Browse files
authored
Extract common helper and constants into separate files for reusability
Differential Revision: D78997240 Pull Request resolved: #13109
1 parent 21bde13 commit 325bbc0

File tree

7 files changed

+366
-232
lines changed

7 files changed

+366
-232
lines changed

examples/models/llava/runner/llava_image_prefiller.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010

1111
#pragma once
1212

13+
#include <executorch/extension/llm/runner/constants.h>
1314
#include <executorch/extension/llm/runner/image_prefiller.h>
1415
#include <executorch/extension/tensor/tensor.h>
1516

1617
namespace example {
1718

19+
using executorch::extension::llm::kImageEncoderMethod;
20+
using executorch::extension::llm::kTextModelMethod;
21+
1822
class ET_EXPERIMENTAL LlavaImagePrefiller {
1923
public:
2024
explicit LlavaImagePrefiller(::executorch::extension::Module* module)
@@ -96,9 +100,6 @@ class ET_EXPERIMENTAL LlavaImagePrefiller {
96100
return methods_loaded;
97101
}
98102

99-
inline static constexpr auto kImageEncoderMethod = "image_encoder";
100-
inline static constexpr auto kTextModelMethod = "text_model";
101-
102103
private:
103104
::executorch::extension::Module* module_;
104105
};

extension/llm/runner/constants.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#pragma once

// Constants shared across the LLM runtime.
namespace executorch::extension::llm {

// Runtime metadata key constants. Each string is the name of an optional
// metadata method exported by the model; when present, executing/reading it
// overrides the runner's built-in default for that setting.
inline constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
inline constexpr auto kBosId = "get_bos_id";
inline constexpr auto kEosIds = "get_eos_ids";
inline constexpr auto kMaxSeqLen = "get_max_seq_len";
inline constexpr auto kMaxContextLen = "get_max_context_len";
inline constexpr auto kVocabSize = "get_vocab_size";
inline constexpr auto kUseKVCache = "use_kv_cache";
inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";

// Multimodal method name conventions: the conventional names of the
// sub-methods a multimodal model exports (e.g. LLaVA's image encoder and
// text decoder).
inline constexpr auto kImageEncoderMethod = "image_encoder";
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
inline constexpr auto kTextModelMethod = "text_model";

} // namespace executorch::extension::llm
extension/llm/runner/llm_runner_helper.cpp

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Implementation of helper utilities for creating and configuring LLM runners
10+
11+
#include <executorch/extension/llm/runner/llm_runner_helper.h>
12+
#include <executorch/extension/llm/runner/stats.h>
13+
#include <executorch/extension/llm/runner/text_llm_runner.h>
14+
#include <executorch/extension/llm/runner/text_prefiller.h>
15+
#include <executorch/extension/llm/runner/text_token_generator.h>
16+
#include <executorch/runtime/platform/runtime.h>
17+
#include <pytorch/tokenizers/hf_tokenizer.h>
18+
#include <pytorch/tokenizers/llama2c_tokenizer.h>
19+
#include <pytorch/tokenizers/sentencepiece.h>
20+
#include <pytorch/tokenizers/tiktoken.h>
21+
22+
namespace executorch {
23+
namespace extension {
24+
namespace llm {
25+
26+
using ::executorch::extension::Module;
27+
using ::executorch::runtime::Error;
28+
29+
std::unique_ptr<tokenizers::Tokenizer> load_tokenizer(
30+
const std::string& tokenizer_path,
31+
std::unique_ptr<std::vector<std::string>> special_tokens,
32+
std::optional<std::string> pattern,
33+
size_t bos_token_index,
34+
size_t eos_token_index) {
35+
runtime::runtime_init();
36+
auto json_tokenizer = std::make_unique<tokenizers::HFTokenizer>();
37+
if (json_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
38+
ET_LOG(Info, "Loaded json tokenizer");
39+
return json_tokenizer;
40+
}
41+
std::unique_ptr<::tokenizers::Tiktoken> tiktoken_tokenizer;
42+
if (special_tokens != nullptr && !pattern.has_value()) {
43+
tiktoken_tokenizer = std::make_unique<::tokenizers::Tiktoken>(
44+
std::move(special_tokens), bos_token_index, eos_token_index);
45+
} else if (special_tokens != nullptr && pattern.has_value()) {
46+
tiktoken_tokenizer = std::make_unique<::tokenizers::Tiktoken>(
47+
pattern.value(),
48+
std::move(special_tokens),
49+
bos_token_index,
50+
eos_token_index);
51+
} else {
52+
tiktoken_tokenizer = std::make_unique<::tokenizers::Tiktoken>();
53+
}
54+
if (tiktoken_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
55+
ET_LOG(Info, "Loaded TikToken tokenizer");
56+
return tiktoken_tokenizer;
57+
}
58+
59+
auto sp_tokenizer = std::make_unique<::tokenizers::SPTokenizer>();
60+
if (sp_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
61+
ET_LOG(Info, "Loaded Sentencepiece tokenizer");
62+
return sp_tokenizer;
63+
}
64+
65+
auto bpe_tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
66+
if (bpe_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
67+
ET_LOG(Info, "Loaded BPE tokenizer");
68+
return bpe_tokenizer;
69+
}
70+
71+
return nullptr;
72+
}
73+
74+
std::unordered_map<std::string, int64_t> get_llm_metadata(
    tokenizers::Tokenizer* tokenizer,
    Module* module) {
  // Initialize metadata with default values. Each key doubles as the name of
  // an optional metadata method the model may export to override the default.
  std::unordered_map<std::string, int64_t> metadata({
      {llm::kEnableDynamicShape, false},
      {llm::kMaxSeqLen, 128},
      {llm::kMaxContextLen, 128},
      {llm::kUseKVCache, true},
      {llm::kUseSDPAWithKVCache, false},
  });

  // Read metadata from the model.
  auto method_names_result = module->method_names();
  if (method_names_result.error() != Error::Ok) {
    ET_LOG(Error, "Failed reading method names");
    return metadata;
  }
  const auto& method_names = method_names_result.get();

  for (auto& pair : metadata) {
    const auto& method_name = pair.first;
    auto& value = pair.second;

    if (method_names.count(method_name)) {
      auto get_result = module->get(method_name);
      // Only overwrite the default when the read succeeded: calling get()
      // on a failed Result is invalid.
      if (get_result.error() == Error::Ok) {
        value =
            get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
      } else {
        ET_LOG(
            Error,
            "Failed reading method %s, using the default value %" PRId64,
            method_name.c_str(),
            value);
      }
    } else {
      ET_LOG(
          Info,
          "Method %s not found, using the default value %" PRId64,
          method_name.c_str(),
          value);
    }
    ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
  }
  // Set tokenizer-related metadata.
  metadata[llm::kBosId] = tokenizer->bos_tok();
  metadata[llm::kVocabSize] = tokenizer->vocab_size();
  return metadata;
}
115+
116+
std::unordered_set<uint64_t> get_eos_ids(
    tokenizers::Tokenizer* tokenizer,
    Module* module) {
  // Start with the tokenizer's default EOS as a fallback.
  std::unordered_set<uint64_t> eos_ids = {tokenizer->eos_tok()};
  // Get model-provided EOS IDs if the method is available.
  auto method_names_result = module->method_names();
  if (method_names_result.error() != Error::Ok) {
    ET_LOG(Error, "Failed reading method names");
    return eos_ids;
  }
  const auto& method_names = method_names_result.get();

  if (method_names.count(llm::kEosIds)) {
    auto execute_result = module->execute(llm::kEosIds);
    if (execute_result.error() != Error::Ok) {
      ET_LOG(Error, "Failed to execute %s", llm::kEosIds);
      // Keep the tokenizer's default EOS rather than returning an empty set,
      // which would make generation unable to ever stop on EOS.
      return eos_ids;
    }
    // The model overrides the default: replace it with the returned IDs.
    eos_ids.clear();
    for (const auto& eos_id : execute_result.get()) {
      auto value = eos_id.toScalar().to<int64_t>();
      eos_ids.emplace(value);
      ET_LOG(Info, "eos_id = %" PRId64, value);
    }
  }
  return eos_ids;
}
143+
144+
std::unique_ptr<TextLLMRunner> create_text_llm_runner(
    const std::string& model_path,
    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
    std::optional<const std::string> data_path,
    float temperature) {
  // Sanity check tokenizer: must be non-null and already loaded by the caller.
  if (!tokenizer || !tokenizer->is_loaded()) {
    ET_LOG(Error, "Tokenizer is null or not loaded");
    return nullptr;
  }

  // Create the Module, optionally backed by an extra data file.
  std::unique_ptr<Module> module;
  if (data_path.has_value()) {
    module = std::make_unique<Module>(
        model_path, data_path.value(), Module::LoadMode::File);
  } else {
    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
  }

  // Get metadata from Module (defaults overridden by model-exported methods).
  ET_LOG(Info, "Reading metadata from model");
  auto metadata = llm::get_llm_metadata(tokenizer.get(), module.get());

  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
      llm::get_eos_ids(tokenizer.get(), module.get()));

  // Create IOManager
  std::unique_ptr<IOManager> io_manager = std::make_unique<IOManager>();

  // Create text_decoder_runner. Owned here via unique_ptr and later moved
  // into the runner; TextPrefiller and TextTokenGenerator hold non-owning
  // raw pointers to it, so it must outlive both (the TextLLMRunner keeps
  // all of them together).
  auto text_decoder_runner =
      std::make_unique<TextDecoderRunner>(module.get(), io_manager.get());

  // Create text_prefiller, configured from the metadata read above.
  auto text_prefiller = std::make_unique<TextPrefiller>(
      text_decoder_runner.get(),
      metadata.at(kUseKVCache),
      metadata.at(kEnableDynamicShape),
      metadata.at(kMaxSeqLen));

  // Create text_token_generator with stats
  auto stats = std::make_unique<Stats>();
  auto text_token_generator = std::make_unique<TextTokenGenerator>(
      tokenizer.get(),
      text_decoder_runner.get(),
      metadata.at(kUseKVCache),
      std::move(eos_ids),
      stats.get());

  // Create and return the Runner instance, transferring ownership of every
  // component. NOTE: the moves below must come after all .get() handouts
  // above so the raw pointers remain valid.
  return std::make_unique<TextLLMRunner>(
      std::move(metadata),
      std::move(tokenizer),
      std::move(module),
      std::move(text_decoder_runner),
      std::move(text_prefiller),
      std::move(io_manager),
      std::move(text_token_generator),
      std::move(stats),
      temperature);
}
207+
208+
} // namespace llm
209+
} // namespace extension
210+
} // namespace executorch
extension/llm/runner/llm_runner_helper.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Helper utilities for creating and configuring LLM runners

#pragma once

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <executorch/extension/llm/runner/constants.h>
#include <executorch/extension/module/module.h>
#include <executorch/runtime/platform/compiler.h>
#include <pytorch/tokenizers/tokenizer.h>

namespace executorch {
namespace extension {
namespace llm {

// Forward declarations
class TextLLMRunner;
class MultimodalRunner;

/**
 * @brief Loads a tokenizer from the specified path
 *
 * This function creates and initializes a tokenizer from a file, with options
 * to customize special tokens and regex patterns. It tries different tokenizer
 * types in order: HF JSON, TikToken, SentencePiece, and BPE.
 *
 * @param tokenizer_path Path to the tokenizer file
 * @param special_tokens Optional list of special tokens to add to the tokenizer
 * @param pattern Optional regex pattern for tokenization
 * @param bos_token_index Index of the beginning-of-sequence token
 * @param eos_token_index Index of the end-of-sequence token
 * @return std::unique_ptr<tokenizers::Tokenizer> Initialized tokenizer
 * instance, or nullptr if no supported format could load the file
 */
ET_EXPERIMENTAL std::unique_ptr<tokenizers::Tokenizer> load_tokenizer(
    const std::string& tokenizer_path,
    std::unique_ptr<std::vector<std::string>> special_tokens = nullptr,
    std::optional<std::string> pattern = std::nullopt,
    size_t bos_token_index = 0,
    size_t eos_token_index = 1);

/**
 * @brief Gets LLM metadata from the model and tokenizer
 *
 * This function extracts metadata from the model such as vocabulary size,
 * context length, and other configuration parameters. It reads metadata
 * methods from the model and combines them with tokenizer information.
 * Missing model methods fall back to built-in defaults.
 *
 * @param tokenizer Initialized tokenizer instance
 * @param module The model module
 * @return std::unordered_map<std::string, int64_t> Metadata key-value pairs
 */
ET_EXPERIMENTAL std::unordered_map<std::string, int64_t> get_llm_metadata(
    tokenizers::Tokenizer* tokenizer,
    Module* module);

/**
 * @brief Gets EOS token IDs from the model and tokenizer
 *
 * This function extracts the end-of-sequence token IDs from the model.
 * It first tries to get EOS IDs from the model's metadata, falling back
 * to the tokenizer's default EOS token.
 *
 * @param tokenizer Initialized tokenizer instance
 * @param module The model module
 * @return std::unordered_set<uint64_t> Set of EOS token IDs
 */
ET_EXPERIMENTAL std::unordered_set<uint64_t> get_eos_ids(
    tokenizers::Tokenizer* tokenizer,
    Module* module);

/**
 * @brief Creates a TextLLMRunner instance with dependency injection
 *
 * This factory function creates and initializes a TextLLMRunner with all
 * necessary components for text generation using the specified model and
 * tokenizer.
 *
 * @param model_path Path to the model file
 * @param tokenizer Initialized tokenizer instance (takes ownership)
 * @param data_path Optional path to additional data required by the model
 * @param temperature Optional temperature parameter for controlling randomness
 * (deprecated)
 * @return std::unique_ptr<TextLLMRunner> Initialized TextLLMRunner instance, or
 * nullptr on failure
 */
ET_EXPERIMENTAL std::unique_ptr<TextLLMRunner> create_text_llm_runner(
    const std::string& model_path,
    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
    std::optional<const std::string> data_path = std::nullopt,
    float temperature = -1.0f);

} // namespace llm
} // namespace extension
} // namespace executorch

0 commit comments

Comments
 (0)