
Commit 8e91b83

lucylq and kirklandsign authored
Finetuned lora example (#99)
* Add Model type toString() (#85)
* finetune lora example
* add lora demo video

---------

Co-authored-by: Hansong Zhang <[email protected]>
1 parent 35f98e2 commit 8e91b83

File tree

3 files changed: +169 −76 lines changed

.gitmodules

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@
 [submodule "program-data-separation/cpp/executorch"]
     path = program-data-separation/cpp/executorch
     url = https://github.com/pytorch/executorch.git
-    branch = release/0.7
+    branch = main

program-data-separation/cpp/lora_example/README.md

Lines changed: 165 additions & 71 deletions
@@ -3,102 +3,151 @@
 This directory contains the C++ code for the LoRA demo.
 
 You'll learn how to:
-1. Export two LoRA PTE files that share a single foundation weight file.
-2. Load and run the LoRA PTE files, and notice that the runtime memory is not doubled as the foundation weights are shared.
+1. Export LoRA PTE files that share a single foundation weight file.
+2. Load and run multiple LoRA PTE files at the same time, and notice that the runtime memory increases by the LoRA adapter size (small) and not the foundation weight size (large), because the foundation weights are shared.
 
 Note:
 - Weight-sharing is supported with the XNNPACK backend.
-- Quantization (outside of embedding quantization) is not supported when weight-sharing.
+- Quantization (outside of embedding quantization) is currently not supported when weight-sharing.
 - There are many ways to fine-tune LoRA adapters. We will go through a few examples to create a demo.
 
-## Size savings.
+## Table of Contents
+- [Size Savings](#size-savings)
+- [Fine-tuning](#finetune-from-scratch-with-unsloth-and-llama)
+- [Installation](#install-executorch)
+- [Export models](#export-models)
+- [Run models](#run-the-executable)
+- [Demo video](#demo-video)
+
+## Size savings
 
 Size results will vary depending on the model and LoRA config. For this demo, we save ~5GB of disk space by storing weights in a separate, sharable file and ~5GB runtime memory by sharing weights at runtime through the XNNPACK weight cache. Detailed results are below.
 
-### XNNPACK weight sharing.
+### XNNPACK weight sharing
 
 The XNNPACK backend is a singleton. Weight sharing is implemented via the XNNPACK weight cache. At delegate init time, XNNPACK checks the weight cache for the weights it needs. If they don't exist, XNNPACK will fetch weights from the NamedDataMap (the API that exposes weights in a PTD file), pack them, store them in the weight cache and free the original. This means we won't keep around multiple copies of the same weights.
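
To build intuition for the pack-once-and-share pattern described above, here is a minimal Python sketch. It is purely conceptual: the names (`WeightCache`, `get_or_pack`, `pack`) are hypothetical illustrations, not XNNPACK's actual API.

```python
# Conceptual sketch only: each named weight is fetched from the PTD-backed
# NamedDataMap, packed once, and then shared by every consumer that asks for it.
class WeightCache:
    def __init__(self, named_data_map):
        self.named_data_map = named_data_map  # name -> raw bytes (e.g. from a .ptd)
        self.packed = {}                      # name -> packed buffer

    def get_or_pack(self, name):
        if name not in self.packed:
            raw = self.named_data_map[name]   # fetch the original weights
            self.packed[name] = pack(raw)     # pack for the target kernels
            # The raw copy can now be freed; only the packed copy is retained.
        return self.packed[name]

def pack(raw):
    return bytes(raw)  # placeholder for backend-specific repacking

# Two models (e.g. two LoRA PTEs) resolving the same foundation weight share
# one packed buffer instead of each holding its own copy.
cache = WeightCache({"layers.0.attention.wq": b"\x00" * 16})
w1 = cache.get_or_pack("layers.0.attention.wq")
w2 = cache.get_or_pack("layers.0.attention.wq")
assert w1 is w2
```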
 
-## Virtual environment setup.
-Create and activate a Python virtual environment:
-```bash
-python3 -m venv .venv && source .venv/bin/activate && pip install --upgrade pip
-```
-Or alternatively, [install conda on your machine](https://conda.io/projects/conda/en/latest/user-guide/install/index.html)
+## Finetune from scratch with Unsloth and Llama
+[Unsloth](https://unsloth.ai/) provides a [colab notebook](https://docs.unsloth.ai/get-started/fine-tuning-llms-guide/datasets-guide#synthetic-dataset-notebook) that showcases how to generate data using the Meta Synthetic Data Kit and then fine-tune on it to create a LoRA adapter.
+
+For this demo, we trained on two datasets:
+1. executorch/docs/source/: an adapter with domain knowledge of executorch. This used the Meta Synthetic Data Kit to generate QA pairs based on the documentation.
+2. Recent Nobel prize winners (2024-2025): an adapter with knowledge beyond the cutoff date of Llama-3-2-1B. This data was taken from [Wikipedia](https://en.wikipedia.org/wiki/List_of_Nobel_laureates) and formatted into the chat template for training.
+
+The training notebook takes a few shortcuts to reduce latency/compute. You can change these settings for better results.
+1. When generating data, play around with the chunk sizes and overlap to see what works best for your dataset.
+2. At the training step, the notebook uses max_steps=60 to speed things up. Setting num_train_epochs=1 (or greater) and max_steps=None gives better results for a full run; see the sketch below.
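
For reference, a minimal sketch of that settings change, assuming the `trl` `SFTTrainer`/`SFTConfig` setup used in the Unsloth notebook; `model` and `dataset` are placeholders for the LoRA-wrapped model and the generated QA dataset from that notebook:

```python
# Hypothetical sketch: trade the quick-demo step cap for a full training epoch.
from trl import SFTTrainer, SFTConfig

args = SFTConfig(
    output_dir="lora_model",
    per_device_train_batch_size=2,
    learning_rate=2e-4,
    # max_steps=60,       # notebook default: fast, but only a partial pass
    num_train_epochs=1,    # one full pass over the generated dataset instead
    logging_steps=10,
)

trainer = SFTTrainer(
    model=model,            # placeholder: LoRA/PEFT-wrapped model from the notebook
    train_dataset=dataset,  # placeholder: generated QA dataset
    args=args,
)
trainer.train()
```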
+
+Unsloth will output the adapter artifacts to the specified directory (in the colab notebook, 'lora_model/'). You will see a few files like these:
 ```bash
-conda create -yn executorch-ptd python=3.10.0 && conda activate executorch-ptd
+-rw-r--r-- 1 lfq users 1092 Oct 15 11:01 adapter_config.json
+-rw-r--r-- 1 lfq users 45118424 Oct 15 11:01 adapter_model.safetensors
+-rw-r--r-- 1 lfq users 3827 Oct 15 11:01 chat_template.jinja
+-rw-r--r-- 1 lfq users 5268 Oct 15 11:01 README.md
+-rw-r--r-- 1 lfq users 454 Oct 15 11:01 special_tokens_map.json
+-rw-r--r-- 1 lfq users 50642 Oct 15 11:01 tokenizer_config.json
+-rw-r--r-- 1 lfq users 17209920 Oct 15 11:01 tokenizer.json
 ```
 
+The files we want are (a quick way to inspect them is sketched below):
+- adapter_config.json
+- adapter_model.safetensors
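
Before exporting, it can help to sanity-check those two artifacts. A small Python sketch (illustrative only; `lora_model/` is the notebook's output directory) that prints the LoRA config and the adapter tensor names:

```python
# Inspect the adapter artifacts produced by the fine-tuning notebook.
import json
from safetensors import safe_open

with open("lora_model/adapter_config.json") as f:
    config = json.load(f)
print("rank:", config.get("r"), "alpha:", config.get("lora_alpha"))
print("target modules:", config.get("target_modules"))

# List a few of the LoRA tensors (lora_A / lora_B pairs per targeted module).
with safe_open("lora_model/adapter_model.safetensors", framework="pt") as f:
    for name in list(f.keys())[:8]:
        print(name, tuple(f.get_tensor(name).shape))
```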
+
 ## Install executorch
-Please install executorch. If you are using your own trained adapter (not the example one), please use a recent nightly build or install from source.
+[Install from source](https://docs.pytorch.org/executorch/stable/using-executorch-building-from-source.html#install-executorch-pip-package-from-source).
 
 ```
-pip install executorch==1.0.0
+# Move to the executorch subdirectory
+cd ~/executorch-examples/program-data-separation/cpp/executorch
+
+# Update to recent main.
+git pull origin main
+
+git submodule sync
+git submodule update --init --recursive
+
+# Install ExecuTorch pip package.
+./install_executorch.sh --editable
 ```
 
-You can also install from the nightly build.
+You can also install from a recent nightly build.
 ```
 pip install executorch==1.1.0.devYYYYMMDD --extra-index-url https://download.pytorch.org/whl/nightly/cpu
 ```
 
-Or [install from source](https://docs.pytorch.org/executorch/stable/using-executorch-building-from-source.html#install-executorch-pip-package-from-source).
+Use main or a recent nightly, as some features are not available in executorch==1.0.0.

+## Export models
 
-## Export the model/s.
-Change into the program-data-separation directory and create a directory to hold exported artifacts.
-```bash
-cd ~/executorch-examples/program-data-separation
-mkdir models
+1. Download the base model. We're using https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct.
 ```
+pip install huggingface_hub
 
-Export models into the `models` directory.
-- The first command generates a regular llama_3_2_1B model.
-- The second command generates a llama_3_2_1B lora model.
+# As this is a gated model, login.
+huggingface-cli login
+huggingface-cli download meta-llama/Llama-3.2-1B-Instruct --local-dir ./Llama-3.2-1B-Instruct
+```
 
-```bash
-sh export_lora.sh
+2. Set your paths and the model name.
+```
+DOWNLOADED_PATH=Llama-3.2-1B-Instruct
+ADAPTER_PATH=lora_model
+MODEL_NAME=<model_name>
 ```
-Expect the files:
-- llama_3_2_1B.pte
-- llama_3_2_1B.ptd
-- llama_3_2_1B_lora.pte
-- foundation_weights.ptd
-- tokenizer.model
 
-llama_3_2_1B.ptd and foundation_weights.ptd contain the same contents, and you can remove llama_3_2_1B.ptd.
-tokenizer.model is copied from the temp directory where we downloaded the HF artifacts. It is used at runtime.
+3. Run the export command. Use a different MODEL_NAME for each adapter.
+```
+python -m executorch.extension.llm.export.export_llm \
+  base.checkpoint="${DOWNLOADED_PATH}/original/consolidated.00.pth" \
+  base.params="${DOWNLOADED_PATH}/original/params.json" \
+  base.tokenizer_path="${DOWNLOADED_PATH}/original/tokenizer.model" \
+  base.adapter_checkpoint="${ADAPTER_PATH}/adapter_model.safetensors" \
+  base.adapter_config="${ADAPTER_PATH}/adapter_config.json" \
+  model.use_kv_cache=true \
+  model.use_sdpa_with_kv_cache=true \
+  model.dtype_override="fp32" \
+  backend.xnnpack.enabled=true \
+  backend.xnnpack.extended_ops=true \
+  export.output_name="${MODEL_NAME}.pte" \
+  export.foundation_weights_file="foundation.ptd"
+```
 
-Note:
-- PTE: contains the program execution logic.
-- PTD: contains the constant tensors used by the PTE. This format is similar to safetensors. It relies on flatbuffers instead of json for serde.
+Expect to see two files: `<model_name>.pte` and `foundation.ptd`. Run the command again to generate more adapter PTE files. You only need to keep one `foundation.ptd` file.
 
-Sample file sizes:
-```
--rw-r--r-- 1 lfq users 5994013600 Oct 17 14:31 foundation.ptd
+You can also run `~/executorch-examples/program-data-separation/export_lora.sh`. This will export the dummy lora model and the base Llama-3-2-1B model PTE files.
+
+Example files, from adapters trained on executorch/docs/source/ and on recent Nobel prize winners:
+```bash
+# executorch docs trained adapter model.
+-rw-r--r-- 1 lfq users 45555712 Oct 17 18:05 et.pte
+# foundation weight file
+-rw-r--r-- 1 lfq users 5994013600 Oct 17 18:05 foundation.ptd
+# dummy lora model.
 -rw-r--r-- 1 lfq users 27628928 Oct 17 14:31 llama_3_2_1B_lora.pte
--rw-r--r-- 1 lfq users 317248 Oct 17 14:28 llama_3_2_1B.pte
+# Nobel prize winners trained adapter model.
+-rw-r--r-- 1 lfq users 45555712 Oct 17 18:00 nobel.pte
 ```
 
-Notice the lora - llama file size difference is about 27.3MB. This is the size of the adapter weights, and changes depending on the LoRA config. This demo is using the config from https://huggingface.co/lucylq/llama3_1B_lora/blob/main/adapter_config.json.
-```
-{"r": 64, "lora_alpha": 128, "target_modules": ["q_proj", "v_proj", "o_proj"], "peft_type": "LORA", "base_model_name_or_path": "meta-llama/Llama-3.2-1B-Instruct"}
-```
+Notice the adapter PTE files are about the same size as the `adapter_model.safetensors`/`adapter_model.pt` files generated during training. The PTE contains the adapter weights (which are not shared) and the program.
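
As a rough cross-check on these sizes: each LoRA-targeted projection of shape (d_out, d_in) adds r*(d_in + d_out) adapter parameters. The sketch below plugs in the dummy adapter config used by this example (r=64, targeting the q/v/o projections, per the adapter_config.json referenced earlier in this diff) and approximate Llama-3.2-1B shapes; both are assumptions, so treat the result as an order-of-magnitude estimate only.

```python
# Back-of-envelope estimate of LoRA adapter size for Llama-3.2-1B-style shapes.
# Dimensions below are assumptions (hidden=2048, 16 layers, 32 heads, 8 KV heads).
hidden = 2048
n_layers = 16
kv_dim = 8 * 64          # 8 KV heads * head_dim 64
r = 64                   # rank from the dummy adapter_config.json

# (d_in, d_out) of the targeted projections in one decoder layer.
targets = {
    "q_proj": (hidden, hidden),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
}

per_layer = sum(r * (d_in + d_out) for d_in, d_out in targets.values())
total_params = per_layer * n_layers
print(f"~{total_params / 1e6:.1f}M LoRA params")
print(f"~{total_params * 2 / 1e6:.0f} MB in bf16, ~{total_params * 4 / 1e6:.0f} MB in fp32")
# Roughly 11M parameters, i.e. tens of MB -- the same order of magnitude as the
# adapter .safetensors and adapter PTE files listed above.
```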
 
-## Install runtime dependencies.
+## Install runtime dependencies
 The ExecuTorch repository is configured as a git submodule at `~/executorch-examples/program-data-separation/cpp/executorch`. To initialize it:
 ```bash
 cd ~/executorch-examples/
+
+# Update to the remote main branch.
+git submodule update --remote program-data-separation/cpp/executorch
 git submodule sync
 git submodule update --init --recursive
 ```
-Install dev requirements for ExecuTorch:
 
+Install dev requirements for ExecuTorch:
 ```bash
 cd ~/executorch-examples/program-data-separation/cpp/executorch
 pip install -r requirements-dev.txt
 ```
 
-## Build the runtime.
+## Build the runtime
 Install some dependencies:
 ```bash
 cd ~/executorch-examples/program-data-separation/cpp/executorch
@@ -111,39 +160,84 @@ cd ~/executorch-examples/program-data-separation/cpp/lora_example
 sh build_example.sh
 ```
 
-## Run the executable.
+## Run the executable
 ```bash
 cd ~/executorch-examples/program-data-separation/cpp/lora_example
 
+DOWNLOADED_PATH=~/path/to/Llama-3.2-1B-Instruct/
 ./build/bin/executorch_program_data_separation \
-  --tokenizer_path="../../tokenizer.model" \
-  --model1="../../models/llama_3_2_1B_lora.pte" \
-  --model2="../../models/llama_3_2_1B.pte" \
-  --weights="../../models/foundation.ptd"
+  --tokenizer_path="${DOWNLOADED_PATH}" \
+  --model1="et.pte" \
+  --model2="nobel.pte" \
+  --weights="foundation.ptd" \
+  --prompt="Who were the winners of the Nobel Prize in Physics in 2025?" \
+  --apply_chat_template
 ```
+Passing in the `DOWNLOADED_PATH` as the tokenizer directory will invoke the HFTokenizer, and parse additional tokenizer files: `tokenizer_config.json` and `special_tokens_map.json`. `special_tokens_map.json` tells us which bos/eos token to use, especially if there are multiple.
 
-You should see some logs showing the Resident Set Size (RSS) at various points of the execution. Some sample logs may look like this:
+`apply_chat_template` formats the prompt according to the Llama chat template, which is what the adapter was trained on.
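
For reference, the chat template applied in `main.cpp` (see the diff below) turns that prompt into a string roughly like this:

```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
Who were the winners of the Nobel Prize in Physics in 2025?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```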
 
+Sample output:
 ```
-Generating with model <model file path>
-RSS after loading model: 6909.328125 MiB
-RSS after prompt prefill: 6909.328125 MiB
-RSS after finishing text generation: 6909.328125 MiB
-
-Generating with lora...
-RSS after loading model: 7941.667969 MiB
-RSS after prompt prefill: 7941.667969 MiB
-RSS after finishing text generation: 7941.667969 MiB
+I 00:00:00.538779 executorch:main.cpp:133] Generating with model et.pte..
+...
+I 00:00:06.999737 executorch:text_llm_runner.cpp:182] RSS after prompt prefill: 6941.296875 MiB (0 if unsupported)
+I don't have information on the winners of the Nobel Prize in Physics in 2025.<|eot_id|>
+...
+I 00:00:11.635379 executorch:main.cpp:141] Generating with model nobel.pte...
+...
+I 00:00:14.109447 executorch:text_llm_runner.cpp:182] RSS after prompt prefill: 8041.632812 MiB (0 if unsupported)
+John Clarke, Michel H. Devoret, John M. Martinis<|eot_id|>
 ```
-There is about ~1.4GB memory increase between running the two models.
-~1GB comes from embeddings that are not lowered to XNNPACK (and currently are not shared). This can be alleviated by quantizing the embeddings by adding the config `quantization.embedding_quantize=\'4,32\'` to the export command.
-~40MB comes from running the non-lora model, to running the lora model.
+We can see that the ExecuTorch-trained adapter model does not have knowledge of the recent Nobel Prize winners, as neither the base model nor the adapter was trained on it. Meanwhile, the Nobel-prize adapter model answers correctly.
 
-You can see the difference without weight-sharing by removing the flag `-DEXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE=True` from `build_example.sh`. Expect to see almost double the memory usage, ie. ~14-15GB instead of ~8GB.
+There is a ~1.1GB memory increase between running the two models.
+Most of that (about 1GB) comes from embeddings that are not lowered to XNNPACK (and currently are not shared). This can be alleviated by quantizing the embeddings by adding the config `quantization.embedding_quantize=\'4,32\'` to the export command.
+~50MB comes from the adapter model, which is not shared.
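
That ~1GB figure lines up with a quick estimate of the unquantized embedding table, assuming the usual Llama 3 vocabulary of roughly 128k tokens and a hidden size of 2048 (both assumed here, not read from the exported model):

```python
# Rough estimate of the fp32 embedding table that is kept outside XNNPACK.
vocab_size = 128_256   # assumed Llama 3 vocabulary size
hidden = 2048          # assumed Llama-3.2-1B hidden size
bytes_fp32 = vocab_size * hidden * 4
print(f"~{bytes_fp32 / 1e9:.2f} GB unquantized")           # ~1.05 GB
print(f"~{bytes_fp32 / 8 / 1e9:.2f} GB at 4-bit weights")   # ~0.13 GB, ignoring group scales
```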
 
-## Clean up.
+Let's try with an ExecuTorch-specific prompt.
 ```bash
-rm -rf build
-cd ~/executorch-examples/program-data-separation
-rm -rf models/
+cd ~/executorch-examples/program-data-separation/cpp/lora_example
+
+DOWNLOADED_PATH=~/path/to/Llama-3.2-1B-Instruct/
+./build/bin/executorch_program_data_separation \
+  --tokenizer_path="${DOWNLOADED_PATH}" \
+  --model1="et.pte" \
+  --model2="nobel.pte" \
+  --weights="foundation.ptd" \
+  --prompt="Help me get started with ExecuTorch in 3 steps" \
+  --apply_chat_template
+```
+
+Sample output:
 ```
+...
+I 00:00:00.554048 executorch:main.cpp:133] Generating with model et.pte...
+...
+Here are 3 steps to get started with ExecuTorch:
+
+Step 1: Install ExecuTorch dependencies. This includes installing Python 3.8+ library, PyTorch library, and the ExecuTorch runtime.
+
+Step 2: Set up a Python environment with pip and a virtual environment (e.g., conda) to isolate ExecuTorch dependencies.
+
+Step 3: Clone the Execu
+I 00:00:27.243400 executorch:text_llm_runner.cpp:206] RSS after finishing text generation: 6940.410156 MiB (0 if unsupported)
+...
+I 00:00:27.243504 executorch:main.cpp:141] Generating with model nobel.pte...
+...
+Here are the 3 steps to get started with Excetorch:
+
+**Step 1: Install Node.js and npm**
+
+Excetorch is a JavaScript compiler, so you'll need Node.js and npm (the Node Package Manager) installed on your computer. You can download Node.js from the official website and npm from the npm website. Follow the installation instructions for your operating system.
+
+**Step 2: Install Excetorch**
+
+
+I 00:00:50.189743 executorch:text_llm_runner.cpp:206] RSS after finishing text generation: 8039.152344 MiB (0 if unsupported)
+```
+
+The ExecuTorch-trained adapter model has domain knowledge of the ExecuTorch codebase, whereas the Nobel-prize trained adapter model does not.
+
+## Demo video
+https://github.com/user-attachments/assets/34f5488d-c1e3-4613-953f-f53745c9b01e

program-data-separation/cpp/lora_example/main.cpp

Lines changed: 3 additions & 4 deletions
@@ -108,7 +108,7 @@ int main(int argc, char *argv[]) {
       llm::create_text_llm_runner(model1, std::move(tokenizer1),
                                   weights, temperature);
   std::unique_ptr<llm::TextLLMRunner> runner2 =
-      llm::create_text_llm_runner(model1, std::move(tokenizer2),
+      llm::create_text_llm_runner(model2, std::move(tokenizer2),
                                   weights, temperature);
 
   llm::GenerationConfig config{
@@ -118,10 +118,11 @@ int main(int argc, char *argv[]) {
 
   std::string formatted_prompt = std::string();
   if (FLAGS_apply_chat_template) {
+    ET_LOG(Info, "Applying chat template...");
     // System Prompt.
     formatted_prompt += "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n";
+    // User Prompt.
     formatted_prompt += "You are a helpful assistant.<|eot_id|>";
-    // User prompt.
     formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n";
     formatted_prompt += prompt;
     formatted_prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>";
@@ -130,7 +131,6 @@ int main(int argc, char *argv[]) {
   }
 
   ET_LOG(Info, "Generating with model %s...", model1);
-  ET_LOG(Info, "Formatted prompt: %s", formatted_prompt.c_str());
   auto error = runner1->generate(formatted_prompt, config);
   if (error != Error::Ok) {
     ET_LOG(Error, "Failed to generate with model %s, error code %zu.",
@@ -145,6 +145,5 @@ int main(int argc, char *argv[]) {
            model2, error);
     return 1;
   }
-
   return 0;
 }
