
Commit c79d28e

Save the model weights at a few hundred megabytes, like the officially provided BIAS-7B.pth. (#75)
* support adapter weights extraction
1 parent f628a78 commit c79d28e

File tree

3 files changed: +92, -13 lines changed

llama_adapter_v2_multimodal/docs/train.md
llama_adapter_v2_multimodal/llama/llama_adapter.py
llama_adapter_v2_multimodal/util/extract_adapter_from_checkpoint.py

llama_adapter_v2_multimodal/docs/train.md

Lines changed: 39 additions & 13 deletions
@@ -1,9 +1,10 @@
-The training process of LLaMA-Adapter V2 consists of the pre-training and fine-tuning phases.
+The training process of LLaMA-Adapter V2 consists of the pre-training and fine-tuning phases.
 
 ## Pre-training
+
 ### Data
-* We use multiple datasets with **image-text pairs** for pre-training. The texts are English-only.
 
+* We use multiple datasets with **image-text pairs** for pre-training. The texts are English-only.
 * For each dataset, the meta file should be organized in the `.csv` format as follows:
 
 ```
@@ -14,8 +15,8 @@ The training process of LLaMA-Adapter V2 consists of the pre-training and fine-t
 ```
 
 Alternatively, you may modify the [`PretrainDataset`](/data/dataset.py) implementation to adapt to your own meta file format.
-
 * Write a `.yaml` config file to specify the datasets for pre-training:
+
 ```
 META:
 - '/path/to/cc3m.csv'
@@ -25,29 +26,25 @@ The training process of LLaMA-Adapter V2 consists of the pre-training and fine-t
 
 ### Start pre-training
 
-We are now ready to start pre-training (please make sure that the original LLaMA weights are available in `/path/to/llama_model_weights`).
+We are now ready to start pre-training (please make sure that the original LLaMA weights are available in `/path/to/llama_model_weights`).
 
 ```bash
 . exps/pretrain.sh /path/to/llama_model_weights /path/to/pretrain-data-config.yaml /output/path
 ```
 
-
-
 ## Fine-tuning
 
 ### Data
 
 * We fine-tune LLaMA-Adapter V2 on text-only as well as image-text instruction following datasets.
-
 * The following lists the datasets we use for training our release weights:
 
-| Name | Link |
-| ------------------------ | ------------------------------------------------------------ |
-| alpaca_gpt4_data.json | [File Link](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/data/alpaca_gpt4_data.json) |
+| Name | Link |
+| ------------------------ | ------------------------------------------------------------------------------------------------------------ |
+| alpaca_gpt4_data.json | [File Link](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/data/alpaca_gpt4_data.json) |
 | alpaca_gpt4_data_zh.json | [File Link](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/data/alpaca_gpt4_data_zh.json) |
-| llava_instruct_150k.json | [File Link](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/raw/main/llava_instruct_150k.json) |
-| alpaca_data_zh_51k.json | [File Link](https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/data/alpaca_data_zh_51k.json) |
-
+| llava_instruct_150k.json | [File Link](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/raw/main/llava_instruct_150k.json) |
+| alpaca_data_zh_51k.json | [File Link](https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/data/alpaca_data_zh_51k.json) |
 * Similar to pre-training, write a `.yaml` config file to specify the datasets for fine-tuning:
 
 ```
@@ -65,3 +62,32 @@ We are now ready to start pre-training (please make sure that the original LLaMA
 /path/to/finetune-data-config.yaml /output/path
 ```
 
+### Test and Save
+
+```python
+import os
+
+import cv2
+import torch
+from PIL import Image
+
+import llama
+from llama.llama_adapter import LLaMA_adapter
+import util.misc as misc
+import util.extract_adapter_from_checkpoint as extract
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+llama_dir = "path/to/llama/"
+llama_type = '7B'
+llama_ckpt_dir = os.path.join(llama_dir, llama_type)
+llama_tokenizer_path = os.path.join(llama_dir, 'tokenizer.model')
+model = LLaMA_adapter(llama_ckpt_dir, llama_tokenizer_path)
+
+# Load the fine-tuned checkpoint and run a quick sanity-check generation.
+misc.load_model(model, 'path/to/finetune/checkpoint.pth')
+model.eval()
+model.to(device)
+
+prompt = llama.format_prompt('your prompt')
+img = Image.fromarray(cv2.imread("your image"))
+img = model.clip_transform(img).unsqueeze(0).to(device)
+
+result = model.generate(img, [prompt])[0]
+print(result)
+
+# Extract and save only the adapter weights. Please end the filename with -llama_type.pth, e.g. adapter-7B.pth.
+extract.save(model, 'path/to/adapter-7B.pth', 'BIAS')
+```
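The `extract.save` call above writes a checkpoint in the same wrapped format as the released adapters, so it should be loadable the same way. Below is a minimal sketch, assuming the package's CLIP-style `llama.load` helper (normally used with the released `BIAS-7B` weights) also accepts a local checkpoint path and returns `(model, preprocess)`; check the repository's demo for the authoritative signature.

```python
# Sketch only: reload the extracted adapter for inference instead of downloading BIAS-7B.pth.
# The llama.load(...) arguments and return values are assumptions based on the repo's demo usage.
import cv2
import torch
from PIL import Image

import llama

device = "cuda" if torch.cuda.is_available() else "cpu"
llama_dir = "path/to/llama/"  # directory containing 7B/ and tokenizer.model

# Pass the locally extracted file where a released name such as "BIAS-7B" would normally go.
model, preprocess = llama.load("path/to/adapter-7B.pth", llama_dir, device=device)
model.eval()

prompt = llama.format_prompt("your prompt")
img = Image.fromarray(cv2.imread("your image"))
img = preprocess(img).unsqueeze(0).to(device)
print(model.generate(img, [prompt])[0])
```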

llama_adapter_v2_multimodal/llama/llama_adapter.py

Lines changed: 1 addition & 0 deletions
@@ -274,6 +274,7 @@ def generate(
         return decoded
 
 
+
 _MODELS = {
     "BIAS-7B": "https://github.com/OpenGVLab/LLaMA-Adapter/releases/download/v.2.0.0/7fa55208379faf2dd862565284101b0e4a2a72114d6490a95e432cf9d9b6c813_BIAS-7B.pth",
     "LORA-BIAS-7B": "https://github.com/OpenGVLab/LLaMA-Adapter/releases/download/v.2.0.0/1bcbffc43484332672092e0024a8699a6eb5f558161aebf98a7c6b1db67224d1_LORA-BIAS-7B.pth",
llama_adapter_v2_multimodal/util/extract_adapter_from_checkpoint.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+import torch
+
+def save(full_model, path, model_type='BIAS'):
+    # Collect the names of the adapter-specific parameters; the frozen LLaMA and CLIP
+    # weights are left out of the saved checkpoint.
+    if model_type == 'BIAS':
+        keys = [
+            f'visual_blocks.{i}.{key}.{suffix}'
+            for i in range(8)
+            for key in ['norm1', 'attn.qkv', 'attn.proj', 'norm2', 'mlp.fc1', 'mlp.fc2']
+            for suffix in ['weight', 'bias']
+        ] + [
+            f'llama.layers.{i}.{key}'
+            for i in range(32)
+            for key in ['attention.gate', 'attention.wq.bias', 'attention.wo.bias', 'feed_forward.w1.bias', 'feed_forward.w2.bias', 'feed_forward.w3.bias', 'attention_norm.weight', 'ffn_norm.weight']
+        ] + [
+            f'{base_key}.{suffix}'
+            for base_key in ['clip_proj_norm', 'visual_proj_norm', 'visual_proj', 'clip_proj']
+            for suffix in ['weight', 'bias']
+        ] + ['llama.norm.weight', 'visual_query.weight', 'adapter_query.weight']
+
+    elif model_type == 'LORA':
+        keys = [
+            f'visual_blocks.{i}.{key}.{suffix}'
+            for i in range(8)
+            for key in [f'norm{j}' for j in range(1, 3)] + ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2']
+            for suffix in ['weight', 'bias']
+        ] + [
+            f'llama.layers.{i}.{key}'
+            for i in range(32)
+            for key in ['attention.gate', 'attention.wq.bias', 'attention.wo.bias', 'feed_forward.w1.bias', 'feed_forward.w2.bias', 'feed_forward.w3.bias', 'attention_norm.weight', 'ffn_norm.weight']
+            + [f'attention.lora_wk_l{j}.weight' for j in range(1, 3)]
+            + [f'attention.lora_wo_l{j}.weight' for j in range(1, 3)]
+            + [f'feed_forward.lora_w{k}_l{j}.weight' for k in range(1, 4) for j in range(1, 3)]
+            + [f'attention.lora_wq_l{j}.weight' for j in range(1, 3)]
+            + [f'attention.lora_wv_l{j}.weight' for j in range(1, 3)]
+            + ['attention.new_gate']
+        ] + [
+            f'{base_key}.{suffix}'
+            for base_key in ['clip_proj_norm', 'visual_proj_norm', 'visual_proj', 'clip_proj']
+            for suffix in ['weight', 'bias']
+        ] + ['llama.norm.weight', 'visual_query.weight', 'adapter_query.weight']
+
+    ## TODO: Add other model types
+
+    # Copy only the selected tensors out of the full model's state dict and wrap them
+    # with the config flags the loader expects.
+    full_model_state_dict = full_model.state_dict()
+    small_weights = {key: full_model_state_dict[key] for key in keys}
+    if model_type == 'BIAS':
+        wrapped_small_weights = {'model': small_weights, 'config': {'w_bias': True, 'w_lora': False, 'lora_rank': 16}}
+    elif model_type == 'LORA':
+        wrapped_small_weights = {'model': small_weights, 'config': {'w_bias': True, 'w_lora': True, 'lora_rank': 16}}
+    # Save the wrapped small weights
+    torch.save(wrapped_small_weights, path)
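As a quick check that extraction worked, the file written by `save` can be reopened directly. The sketch below relies only on the dictionary layout shown above; the paths are placeholders.

```python
# Inspect an adapter checkpoint produced by util/extract_adapter_from_checkpoint.py.
import os

import torch

path = "path/to/adapter-7B.pth"  # placeholder
ckpt = torch.load(path, map_location="cpu")

print(ckpt["config"])                 # e.g. {'w_bias': True, 'w_lora': False, 'lora_rank': 16}
print(len(ckpt["model"]), "tensors")  # adapter-related weights only, not the full model

# Rough sizes, to confirm the checkpoint lands in the hundreds of megabytes rather than gigabytes.
n_params = sum(t.numel() for t in ckpt["model"].values())
print(f"{n_params / 1e6:.1f}M adapter parameters")
print(f"{os.path.getsize(path) / 2**20:.1f} MiB on disk")
```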
