Skip to content

Commit f2f06c1

Browse files
committed
weight sharing
1 parent 9ceef85 commit f2f06c1

File tree

4 files changed

+64
-28
lines changed

4 files changed

+64
-28
lines changed

program-data-separation/cpp/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ if(EXECUTORCH_BUILD_LINEAR_DEMO)
4141
endif()
4242
if(EXECUTORCH_BUILD_LORA_DEMO)
4343
list(APPEND DEMO_SOURCES "lora_example/main.cpp")
44-
add_subdirectory("executorch/examples/models/llama/runner")
45-
list(APPEND LINK_LIBS llama_runner)
4644
endif()
4745

4846
# Create executable

program-data-separation/cpp/lora_example/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ sh build_example.sh
7676
```
7777

7878
## Run the executable.
79-
```
79+
```bash
80+
cd ~/executorch-examples/program-data-separation/cpp/lora_example
81+
8082
./build/bin/executorch_program_data_separation --lora_model_path=../../llama_3_2_1B_lora.pte --llama_model_path=../../llama_3_2_1B.pte --tokenizer_path=../../tokenizer.model --data_path=../../foundation.ptd
8183
```
8284

program-data-separation/cpp/lora_example/build_example.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mkdir -p build
77
cd build
88

99
# Configure CMake
10-
cmake -DCMAKE_BUILD_TYPE=Release -DEXECUTORCH_BUILD_LORA_DEMO=True ../..
10+
cmake -DCMAKE_BUILD_TYPE=Release -DEXECUTORCH_BUILD_LORA_DEMO=True -DEXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE=True ../..
1111

1212
# Build the project
1313
cmake --build . -j$(nproc)

program-data-separation/cpp/lora_example/main.cpp

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,18 @@
66
* LICENSE file in the root directory of this source tree.
77
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
88
*/
9+
10+
#include <memory>
11+
#include <string>
12+
#include <vector>
13+
914
#include <gflags/gflags.h>
1015

11-
#include <executorch/examples/models/llama/runner/runner.h>
16+
#include <executorch/extension/llm/runner/llm_runner_helper.h>
17+
#include <executorch/extension/llm/runner/stats.h>
18+
#include <executorch/extension/llm/runner/text_llm_runner.h>
19+
#include <executorch/extension/llm/runner/text_prefiller.h>
20+
#include <executorch/extension/llm/runner/text_token_generator.h>
1221

1322
#if defined(ET_USE_THREADPOOL)
1423
#include <executorch/extension/threadpool/cpuinfo_utils.h>
@@ -36,7 +45,30 @@ DEFINE_int32(
3645
"max_seq_len. If the number of input tokens + seq_len > max_seq_len, the "
3746
"output will be truncated to max_seq_len tokens.");
3847

39-
using namespace ::executorch::extension;
48+
using executorch::extension::Module;
49+
using executorch::runtime::Error;
50+
namespace llm = executorch::extension::llm;
51+
52+
namespace {
53+
static constexpr int32_t kSpecialTokensSize = 256;
54+
static inline std::unique_ptr<std::vector<std::string>>
55+
_get_default_special_tokens() {
56+
auto special_tokens =
57+
std::make_unique<std::vector<std::string>>(std::vector<std::string>{
58+
"<|begin_of_text|>", "<|end_of_text|>",
59+
"<|reserved_special_token_0|>", "<|reserved_special_token_1|>",
60+
"<|finetune_right_pad_id|>", "<|step_id|>", "<|start_header_id|>",
61+
"<|end_header_id|>", "<|eom_id|>", "<|eot_id|>", "<|python_tag|>"});
62+
// pad the rest of the special tokens with reserved tokens
63+
ssize_t reserved_special_token_num = 2;
64+
while (special_tokens->size() < kSpecialTokensSize) {
65+
special_tokens->emplace_back("<|reserved_special_token_" +
66+
std::to_string(reserved_special_token_num++) +
67+
"|>");
68+
}
69+
return special_tokens;
70+
}
71+
} // namespace
4072

4173
int main(int argc, char *argv[]) {
4274
ET_LOG(Info, "Running program-data separation lora example...");
@@ -53,37 +85,41 @@ int main(int argc, char *argv[]) {
5385
int32_t seq_len = 128;
5486
int32_t cpu_threads = -1;
5587

56-
// Create runner for lora model.
57-
std::unique_ptr<::executorch::extension::llm::TextLLMRunner> lora_runner =
58-
example::create_llama_runner(lora_model_path, tokenizer_path, data_path);
59-
if (lora_runner == nullptr) {
60-
ET_LOG(Error, "Failed to create lora_runner.");
88+
// Create tokenizers.
89+
std::unique_ptr<tokenizers::Tokenizer> tokenizer1 =
90+
llm::load_tokenizer(tokenizer_path, _get_default_special_tokens());
91+
std::unique_ptr<tokenizers::Tokenizer> tokenizer2 =
92+
llm::load_tokenizer(tokenizer_path, _get_default_special_tokens());
93+
94+
if (tokenizer1 == nullptr || tokenizer2 == nullptr) {
95+
ET_LOG(Info,
96+
"Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c "
97+
"tokenizer, make sure the artifact is one of these types",
98+
tokenizer_path);
6199
return 1;
62100
}
63101

64-
// create runner for llama model
65-
std::unique_ptr<::executorch::extension::llm::TextLLMRunner> llama_runner =
66-
example::create_llama_runner(llama_model_path, tokenizer_path, data_path);
67-
if (llama_runner == nullptr) {
68-
ET_LOG(Error, "Failed to create llama_runner.");
69-
return 1;
70-
}
102+
// Create runners.
103+
std::unique_ptr<llm::TextLLMRunner> llama_runner =
104+
llm::create_text_llm_runner(llama_model_path, std::move(tokenizer1),
105+
data_path, temperature);
106+
std::unique_ptr<llm::TextLLMRunner> lora_runner = llm::create_text_llm_runner(
107+
lora_model_path, std::move(tokenizer2), data_path, temperature);
71108

72-
// generate
73-
executorch::extension::llm::GenerationConfig config{
74-
.seq_len = seq_len, .temperature = temperature};
109+
// Generate.
110+
llm::GenerationConfig config{.seq_len = seq_len, .temperature = temperature};
75111

76-
auto error = lora_runner->generate(prompt, config);
77-
if (error != executorch::runtime::Error::Ok) {
78-
ET_LOG(Error, "Failed to generate with lora_runner, error code %zu.",
112+
ET_LOG(Info, "Generating with llama...");
113+
auto error = llama_runner->generate(prompt, config);
114+
if (error != Error::Ok) {
115+
ET_LOG(Error, "Failed to generate with llama_runner, error code %zu.",
79116
error);
80117
return 1;
81118
}
82119

83-
ET_LOG(Info, "Generating with llama...");
84-
error = llama_runner->generate(prompt, config);
85-
if (error != executorch::runtime::Error::Ok) {
86-
ET_LOG(Error, "Failed to generate with llama_runner, error code %zu.",
120+
error = lora_runner->generate(prompt, config);
121+
if (error != Error::Ok) {
122+
ET_LOG(Error, "Failed to generate with lora_runner, error code %zu.",
87123
error);
88124
return 1;
89125
}

0 commit comments

Comments (0)