Skip to content

Commit 5786f3f

Browse files
authored
Lora example (#52)
* lora * lora example * weight sharing * address comments
1 parent a726c0b commit 5786f3f

File tree

8 files changed

+371
-15
lines changed

8 files changed

+371
-15
lines changed

program-data-separation/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,7 @@ To enable LoRA, we generate:
2727

2828
Multiple LoRA-adapted PTE files can share the same foundation weights and adding a model adapted to a new task incurs minimal binary size and runtime memory overhead.
2929

30+
Please take a look at [program-data-separation/cpp/lora_example](lora_example/) for a demo of the program-data separation APIs with LoRA. This example generates and runs a LoRA and a non-LoRA model that share foundation weights. At runtime, we see that memory usage does not double.
31+
3032
### Requirements
3133
LoRA is currently supported on executorch main. [Please install ExecuTorch pip package from source](https://docs.pytorch.org/executorch/stable/using-executorch-building-from-source.html#install-executorch-pip-package-from-source), until executorch==1.0 is released.

program-data-separation/cpp/CMakeLists.txt

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,57 @@ option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
1414
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
1515
option(EXECUTORCH_BUILD_XNNPACK "" ON)
1616

17-
# Add ExecuTorch subdirectory
17+
# Dependencies required for llm runner in lora demo.
18+
if(EXECUTORCH_BUILD_LORA_DEMO)
19+
option(EXECUTORCH_BUILD_EXTENSION_LLM "" ON)
20+
option(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER "" ON)
21+
option(EXECUTORCH_BUILD_KERNELS_LLM "" ON)
22+
option(EXECUTORCH_BUILD_KERNELS_LLM_AOT "" ON)
23+
endif()
24+
25+
# Add ExecuTorch subdirectory, after setting options.
1826
add_subdirectory("executorch")
1927

20-
set(DEMO_SOURCES linear_example/main.cpp)
28+
set(LINK_LIBS executorch
29+
executorch::extensions
30+
xnnpack_backend
31+
# NOTE: xnnpack_backend has to go before
32+
# kernels otherwise it doesn't get registered.
33+
executorch::kernels
34+
gflags
35+
)
36+
37+
# Add sources and dependencies.
38+
set(DEMO_SOURCES "")
39+
if(EXECUTORCH_BUILD_LINEAR_DEMO)
40+
list(APPEND DEMO_SOURCES "linear_example/main.cpp")
41+
endif()
42+
if(EXECUTORCH_BUILD_LORA_DEMO)
43+
list(APPEND DEMO_SOURCES "lora_example/main.cpp")
44+
endif()
2145

2246
# Create executable
2347
add_executable(executorch_program_data_separation ${DEMO_SOURCES})
2448

25-
# Include directories
26-
target_include_directories(executorch_program_data_separation PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
27-
2849
# Link libraries
2950
target_link_libraries(
3051
executorch_program_data_separation
31-
PRIVATE executorch
32-
extension_module_static
33-
extension_flat_tensor
34-
extension_tensor
35-
xnnpack_backend
36-
portable_ops_lib
37-
portable_kernels
38-
gflags
52+
PRIVATE ${LINK_LIBS}
3953
)
4054

55+
# Include directories for lora demo.
56+
if(EXECUTORCH_BUILD_LORA_DEMO)
57+
# Include directories
58+
target_include_directories(executorch_program_data_separation PRIVATE
59+
${CMAKE_CURRENT_SOURCE_DIR}
60+
${CMAKE_CURRENT_SOURCE_DIR}/executorch/extension/llm/tokenizers/include
61+
)
62+
target_link_libraries(
63+
executorch_program_data_separation
64+
PUBLIC tokenizers::tokenizers
65+
)
66+
endif()
67+
4168
# Set output directory
4269
set_target_properties(executorch_program_data_separation
4370
PROPERTIES
Submodule executorch updated 328 files

program-data-separation/cpp/linear_example/build_example.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mkdir -p build
77
cd build
88

99
# Configure CMake
10-
cmake -DCMAKE_BUILD_TYPE=Release ../..
10+
cmake -DCMAKE_BUILD_TYPE=Release -DEXECUTORCH_BUILD_LINEAR_DEMO=True ../..
1111

1212
# Build the project
1313
cmake --build . -j$(nproc)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# ExecuTorch LoRA Demo
2+
3+
This directory contains the C++ code for the LoRA demo. This demo showcases how to export and run models that share the same architecture without inflating binary file size or runtime memory.
4+
5+
Specifically, this demo walks through exporting and running a LoRA and non-LoRA llama model without duplication of shared foundation weights on disk or in memory.
6+
7+
1. Exporting LoRA and non-LoRA llama models, lowered to XNNPACK, with weights in a separate file.
8+
2. Loading and running models with weights in a separate file.
9+
3. Runtime weight sharing via XNNPACK.
10+
11+
## Size savings.
12+
13+
Size results will vary depending on the model, quantization and LoRA config. For this demo, we save ~5GB of disk space by storing weights in a separate, sharable file and ~5GB runtime memory by sharing weights at runtime through the XNNPACK weight cache. Detailed results are below.
14+
15+
### XNNPACK weight sharing.
16+
17+
The XNNPACK backend is a singleton. Weight sharing is implemented via the XNNPACK weight cache. At delegate init time, XNNPACK checks the weight cache for the weights it needs. If they don't exist, XNNPACK will fetch weights from the NamedDataMap (the API that exposes weights in a PTD file), pack them, store them in the weight cache and free the original. This means we won't keep around multiple copies of the same weights.
18+
19+
## Virtual environment setup.
20+
Create and activate a Python virtual environment:
21+
```bash
22+
python3 -m venv .venv && source .venv/bin/activate && pip install --upgrade pip
23+
```
24+
Or alternatively, [install conda on your machine](https://conda.io/projects/conda/en/latest/user-guide/install/index.html)
25+
```bash
26+
conda create -yn executorch-ptd python=3.10.0 && conda activate executorch-ptd
27+
```
28+
29+
Install dependencies:
30+
LoRA isn't available in the 0.7.0 release of ExecuTorch. Instead, please install from source until ExecuTorch 1.0 is released.
31+
32+
[Install ExecuTorch pip package from source](https://docs.pytorch.org/executorch/stable/using-executorch-building-from-source.html#install-executorch-pip-package-from-source).
33+
34+
Currently, the LoRA changes aren't in nightlies. Once they are in, you can also install from the nightly build.
35+
```
36+
pip install executorch==0.8.0.devYYYYMMDD --extra-index-url https://download.pytorch.org/whl/nightly/cpu
37+
```
38+
39+
## Export the model/s.
40+
Change into the program-data-separation directory and create a directory to hold exported artifacts.
41+
```bash
42+
cd ~/executorch-examples/program-data-separation
43+
mkdir models
44+
```
45+
46+
Export models into the `models` directory. The first command will generated undelegated model/data files, and the second will generate XNNPACK-delegated model/data files.
47+
```bash
48+
sh export_lora.sh
49+
```
50+
Expect the files:
51+
- llama_3_2_1B.pte
52+
- llama_3_2_1B.ptd
53+
- llama_3_2_1B_lora.pte
54+
- foundation_weights.ptd
55+
- tokenizer.model
56+
57+
llama_3_2_1B.ptd and foundation_weights.ptd contain the same contents, and you can remove llama_3_2_1B.ptd.
58+
tokenizer.model is copied from the temp directory where we downloaded the HF artifacts. It will be used at runtime.
59+
60+
Note:
61+
- PTE: contains the program execution logic.
62+
- PTD: contains the constant tensors used by the PTE. This format is similar to safetensors, but relying on flatbuffer instead of json for serde.
63+
64+
Sample file sizes:
65+
```
66+
-rw-r--r-- 1 lfq users 4943000480 Aug 11 15:55 foundation.ptd
67+
-rw-r--r-- 1 lfq users 1078636416 Aug 11 15:55 llama_3_2_1B_lora.pte
68+
-rw-r--r-- 1 lfq users 1051324736 Aug 11 15:53 llama_3_2_1B.pte
69+
```
70+
71+
Notice the lora - llama file size difference is about 27.3MB. This will change depending on the LoRA config. This demo is using the config from https://huggingface.co/lucylq/llama3_1B_lora/blob/main/adapter_config.json
72+
```
73+
{"r": 64, "lora_alpha": 128, "target_modules": ["q_proj", "v_proj", "o_proj"], "peft_type": "LORA", "base_model_name_or_path": "meta-llama/Llama-3.2-1B-Instruct"}
74+
```
75+
76+
## Install runtime dependencies.
77+
The ExecuTorch repository is configured as a git submodule at `~/executorch-examples/program-data-separation/cpp/executorch`. To initialize it:
78+
```bash
79+
cd ~/executorch-examples/
80+
git submodule sync
81+
git submodule update --init --recursive
82+
```
83+
Install dev requirements for ExecuTorch:
84+
85+
```bash
86+
cd ~/executorch-examples/program-data-separation/cpp/executorch
87+
pip install -r requirements-dev.txt
88+
```
89+
90+
## Build the runtime.
91+
Install some dependencies:
92+
```bash
93+
cd ~/executorch-examples/program-data-separation/cpp/executorch
94+
sh examples/models/llama/install_requirements.sh
95+
```
96+
97+
Build the executable:
98+
```bash
99+
cd ~/executorch-examples/program-data-separation/cpp/lora_example
100+
sh build_example.sh
101+
```
102+
103+
## Run the executable.
104+
```bash
105+
cd ~/executorch-examples/program-data-separation/cpp/lora_example
106+
107+
./build/bin/executorch_program_data_separation --lora_model_path=../../llama_3_2_1B_lora.pte --llama_model_path=../../llama_3_2_1B.pte --tokenizer_path=../../tokenizer.model --foundation_weights_path=../../foundation.ptd
108+
```
109+
110+
You should see some logs showing the Resident Set Size (RSS) at various points of the execution. Some sample logs may look like this:
111+
112+
```
113+
Generating with llama...
114+
RSS after loading model: 7886.125000 MiB
115+
RSS after prompt prefill: 7886.125000 MiB
116+
RSS after finishing text generation: 7886.125000 MiB
117+
118+
Generating with lora...
119+
RSS after loading model: 7933.523438 MiB
120+
RSS after prompt prefill: 7933.523438 MiB
121+
RSS after finishing text generation: 7933.523438 MiB
122+
```
123+
Notice the memory increase of ~47 MiB from running llama model to running lora model. You can see the difference without weight-sharing by removing the flag `-DEXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE=True` from `build_example.sh`.
124+
125+
## Clean up.
126+
```bash
127+
rm -rf build
128+
cd ~/executorch-examples/program-data-separation
129+
rm -rf *.pte *.ptd tokenizer.model
130+
```
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash
2+
set -e
3+
4+
# Clean and create build directory if it doesn't exist
5+
rm -rf build
6+
mkdir -p build
7+
cd build
8+
9+
# Configure CMake
10+
cmake -DCMAKE_BUILD_TYPE=Release -DEXECUTORCH_BUILD_LORA_DEMO=True -DEXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE=True ../..
11+
12+
# Build the project
13+
cmake --build . -j$(nproc)
14+
15+
echo "Build complete! Executable located at: ./build/bin/executorch_program_data_separation"
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
8+
*/
9+
10+
#include <memory>
11+
#include <string>
12+
#include <vector>
13+
14+
#include <gflags/gflags.h>
15+
16+
#include <executorch/extension/llm/runner/llm_runner_helper.h>
17+
#include <executorch/extension/llm/runner/stats.h>
18+
#include <executorch/extension/llm/runner/text_llm_runner.h>
19+
#include <executorch/extension/llm/runner/text_prefiller.h>
20+
#include <executorch/extension/llm/runner/text_token_generator.h>
21+
22+
#if defined(ET_USE_THREADPOOL)
23+
#include <executorch/extension/threadpool/cpuinfo_utils.h>
24+
#include <executorch/extension/threadpool/threadpool.h>
25+
#endif
26+
27+
DEFINE_string(lora_model_path, "llama_3_2_1B_lora.pte",
28+
"LoRA model serialized in flatbuffer format.");
29+
DEFINE_string(llama_model_path, "llama_3_2_1B.pte",
30+
"Model serialized in flatbuffer format.");
31+
DEFINE_string(foundation_weights_path, "foundation.ptd",
32+
"Foundation weights serialized in flatbuffer format.");
33+
34+
DEFINE_string(tokenizer_path, "tokenizer.model", "Tokenizer stuff.");
35+
36+
DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
37+
38+
DEFINE_double(temperature, 0,
39+
"Temperature; Default is 0. 0 = greedy argmax sampling "
40+
"(deterministic). Lower temperature = more deterministic");
41+
42+
DEFINE_int32(
43+
seq_len, 128,
44+
"Total number of tokens to generate (prompt + output). Defaults to "
45+
"max_seq_len. If the number of input tokens + seq_len > max_seq_len, the "
46+
"output will be truncated to max_seq_len tokens.");
47+
48+
using executorch::extension::Module;
49+
using executorch::runtime::Error;
50+
namespace llm = executorch::extension::llm;
51+
52+
namespace {
53+
static constexpr int32_t kSpecialTokensSize = 256;
54+
static inline std::unique_ptr<std::vector<std::string>>
55+
_get_default_special_tokens() {
56+
auto special_tokens =
57+
std::make_unique<std::vector<std::string>>(std::vector<std::string>{
58+
"<|begin_of_text|>", "<|end_of_text|>",
59+
"<|reserved_special_token_0|>", "<|reserved_special_token_1|>",
60+
"<|finetune_right_pad_id|>", "<|step_id|>", "<|start_header_id|>",
61+
"<|end_header_id|>", "<|eom_id|>", "<|eot_id|>", "<|python_tag|>"});
62+
// pad the rest of the special tokens with reserved tokens
63+
ssize_t reserved_special_token_num = 2;
64+
while (special_tokens->size() < kSpecialTokensSize) {
65+
special_tokens->emplace_back("<|reserved_special_token_" +
66+
std::to_string(reserved_special_token_num++) +
67+
"|>");
68+
}
69+
return special_tokens;
70+
}
71+
} // namespace
72+
73+
int main(int argc, char *argv[]) {
74+
ET_LOG(Info, "Running program-data separation lora example...");
75+
76+
gflags::ParseCommandLineFlags(&argc, &argv, true);
77+
78+
const char *lora_model_path = FLAGS_lora_model_path.c_str();
79+
const char *llama_model_path = FLAGS_llama_model_path.c_str();
80+
const char *foundation_weights_path = FLAGS_foundation_weights_path.c_str();
81+
82+
const char *tokenizer_path = FLAGS_tokenizer_path.c_str();
83+
const char *prompt = FLAGS_prompt.c_str();
84+
float temperature = FLAGS_temperature;
85+
int32_t seq_len = 128;
86+
int32_t cpu_threads = -1;
87+
88+
// Create tokenizers.
89+
std::unique_ptr<tokenizers::Tokenizer> tokenizer1 =
90+
llm::load_tokenizer(tokenizer_path, _get_default_special_tokens());
91+
std::unique_ptr<tokenizers::Tokenizer> tokenizer2 =
92+
llm::load_tokenizer(tokenizer_path, _get_default_special_tokens());
93+
94+
if (tokenizer1 == nullptr || tokenizer2 == nullptr) {
95+
ET_LOG(Info,
96+
"Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c "
97+
"tokenizer, make sure the artifact is one of these types",
98+
tokenizer_path);
99+
return 1;
100+
}
101+
102+
// Create runners.
103+
std::unique_ptr<llm::TextLLMRunner> llama_runner =
104+
llm::create_text_llm_runner(llama_model_path, std::move(tokenizer1),
105+
foundation_weights_path, temperature);
106+
std::unique_ptr<llm::TextLLMRunner> lora_runner =
107+
llm::create_text_llm_runner(lora_model_path, std::move(tokenizer2),
108+
foundation_weights_path, temperature);
109+
110+
// Generate.
111+
llm::GenerationConfig config{.seq_len = seq_len, .temperature = temperature};
112+
113+
ET_LOG(Info, "Generating with llama...");
114+
auto error = llama_runner->generate(prompt, config);
115+
if (error != Error::Ok) {
116+
ET_LOG(Error, "Failed to generate with llama_runner, error code %zu.",
117+
error);
118+
return 1;
119+
}
120+
121+
error = lora_runner->generate(prompt, config);
122+
if (error != Error::Ok) {
123+
ET_LOG(Error, "Failed to generate with lora_runner, error code %zu.",
124+
error);
125+
return 1;
126+
}
127+
128+
return 0;
129+
}

0 commit comments

Comments
 (0)