  * LICENSE file in the root directory of this source tree.
  * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */
+
+#include <memory>
+#include <string>
+#include <vector>
+
 #include <gflags/gflags.h>
 
-#include <executorch/examples/models/llama/runner/runner.h>
+#include <executorch/extension/llm/runner/llm_runner_helper.h>
+#include <executorch/extension/llm/runner/stats.h>
+#include <executorch/extension/llm/runner/text_llm_runner.h>
+#include <executorch/extension/llm/runner/text_prefiller.h>
+#include <executorch/extension/llm/runner/text_token_generator.h>
 
 #if defined(ET_USE_THREADPOOL)
 #include <executorch/extension/threadpool/cpuinfo_utils.h>
@@ -36,7 +45,30 @@ DEFINE_int32(
     "max_seq_len. If the number of input tokens + seq_len > max_seq_len, the "
     "output will be truncated to max_seq_len tokens.");
 
-using namespace ::executorch::extension;
+using executorch::extension::Module;
+using executorch::runtime::Error;
+namespace llm = executorch::extension::llm;
+
+namespace {
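+// Builds the default special-token table passed to load_tokenizer(): the
+// named tokens below, padded with reserved tokens to kSpecialTokensSize
+// (256) entries.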
+static constexpr int32_t kSpecialTokensSize = 256;
+static inline std::unique_ptr<std::vector<std::string>>
+_get_default_special_tokens() {
+  auto special_tokens =
+      std::make_unique<std::vector<std::string>>(std::vector<std::string>{
+          "<|begin_of_text|>", "<|end_of_text|>",
+          "<|reserved_special_token_0|>", "<|reserved_special_token_1|>",
+          "<|finetune_right_pad_id|>", "<|step_id|>", "<|start_header_id|>",
+          "<|end_header_id|>", "<|eom_id|>", "<|eot_id|>", "<|python_tag|>"});
+  // Pad the rest of the special tokens with reserved tokens.
+  ssize_t reserved_special_token_num = 2;
+  while (special_tokens->size() < kSpecialTokensSize) {
+    special_tokens->emplace_back("<|reserved_special_token_" +
+                                 std::to_string(reserved_special_token_num++) +
+                                 "|>");
+  }
+  return special_tokens;
+}
+} // namespace
 
 int main(int argc, char* argv[]) {
   ET_LOG(Info, "Running program-data separation lora example...");
@@ -53,37 +85,41 @@ int main(int argc, char *argv[]) {
   int32_t seq_len = 128;
   int32_t cpu_threads = -1;
 
-  // Create runner for lora model.
-  std::unique_ptr<::executorch::extension::llm::TextLLMRunner> lora_runner =
-      example::create_llama_runner(lora_model_path, tokenizer_path, data_path);
-  if (lora_runner == nullptr) {
-    ET_LOG(Error, "Failed to create lora_runner.");
+  // Create tokenizers.
+  std::unique_ptr<tokenizers::Tokenizer> tokenizer1 =
+      llm::load_tokenizer(tokenizer_path, _get_default_special_tokens());
+  std::unique_ptr<tokenizers::Tokenizer> tokenizer2 =
+      llm::load_tokenizer(tokenizer_path, _get_default_special_tokens());
+
+  if (tokenizer1 == nullptr || tokenizer2 == nullptr) {
+    ET_LOG(Info,
+           "Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c "
+           "tokenizer, make sure the artifact is one of these types",
+           tokenizer_path);
     return 1;
   }
 
-  // create runner for llama model
-  std::unique_ptr<::executorch::extension::llm::TextLLMRunner> llama_runner =
-      example::create_llama_runner(llama_model_path, tokenizer_path, data_path);
-  if (llama_runner == nullptr) {
-    ET_LOG(Error, "Failed to create llama_runner.");
-    return 1;
-  }
+  // Create runners.
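+  // create_text_llm_runner takes ownership of the tokenizer passed to it
+  // (note the std::move), which is why a separate tokenizer instance was
+  // loaded above for each runner.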
+  std::unique_ptr<llm::TextLLMRunner> llama_runner =
+      llm::create_text_llm_runner(llama_model_path, std::move(tokenizer1),
+                                  data_path, temperature);
+  std::unique_ptr<llm::TextLLMRunner> lora_runner = llm::create_text_llm_runner(
+      lora_model_path, std::move(tokenizer2), data_path, temperature);
 
-  // generate
-  executorch::extension::llm::GenerationConfig config{
-      .seq_len = seq_len, .temperature = temperature};
+  // Generate.
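+  // Per the seq_len flag's note above: if input tokens + seq_len exceeds the
+  // model's max_seq_len, output is truncated to max_seq_len tokens.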
+  llm::GenerationConfig config{.seq_len = seq_len, .temperature = temperature};
 
-  auto error = lora_runner->generate(prompt, config);
-  if (error != executorch::runtime::Error::Ok) {
-    ET_LOG(Error, "Failed to generate with lora_runner, error code %zu.",
+  ET_LOG(Info, "Generating with llama...");
+  auto error = llama_runner->generate(prompt, config);
+  if (error != Error::Ok) {
+    ET_LOG(Error, "Failed to generate with llama_runner, error code %zu.",
            error);
     return 1;
   }
 
-  ET_LOG(Info, "Generating with llama...");
-  error = llama_runner->generate(prompt, config);
-  if (error != executorch::runtime::Error::Ok) {
-    ET_LOG(Error, "Failed to generate with llama_runner, error code %zu.",
+  error = lora_runner->generate(prompt, config);
+  if (error != Error::Ok) {
+    ET_LOG(Error, "Failed to generate with lora_runner, error code %zu.",
            error);
     return 1;
   }
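
Distilled from the diff above, the per-runner flow reduces to the following standalone sketch. This is a minimal illustration, not the example's actual code: the run_once name, the temperature value, and the model/tokenizer/data paths are placeholder assumptions (the example reads these from gflags), and it reuses the _get_default_special_tokens() helper defined in the diff.

#include <memory>
#include <string>

#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/text_llm_runner.h>

namespace llm = executorch::extension::llm;

// Assumes _get_default_special_tokens() from the diff is in scope.
int run_once() {
  // Placeholder paths; the real example populates these from gflags.
  const std::string model_path = "lora_model.pte";
  const std::string tokenizer_path = "tokenizer.model";
  const std::string data_path = "foundation_weights.ptd"; // shared weights

  // The runner takes ownership of the tokenizer, so load a fresh instance.
  std::unique_ptr<tokenizers::Tokenizer> tokenizer =
      llm::load_tokenizer(tokenizer_path, _get_default_special_tokens());
  if (tokenizer == nullptr) {
    return 1;
  }

  std::unique_ptr<llm::TextLLMRunner> runner = llm::create_text_llm_runner(
      model_path, std::move(tokenizer), data_path, /*temperature=*/0.8f);
  if (runner == nullptr) {
    return 1;
  }

  llm::GenerationConfig config{.seq_len = 128, .temperature = 0.8f};
  return runner->generate("Hello", config) == executorch::runtime::Error::Ok
      ? 0
      : 1;
}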