Commit fe7ec53

update pretrain and finetune codes of llama_adapter_v2_multimodal
1 parent f1e1911 commit fe7ec53

14 files changed: +1387, -48 lines
Lines changed: 157 additions & 0 deletions
```python
import torch
import yaml
from torch.utils.data import Dataset
from PIL import Image
import json
import llama.utils
from llama import Tokenizer
import copy
import torchvision.transforms as transforms
import pandas as pd
import random
import cv2

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

# create data
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333),
                                 interpolation=BICUBIC, antialias=None),  # 3 is bicubic
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])


class FinetuneDataset(Dataset):
    def __init__(self, config_path, transform, max_words=30, tokenizer_path=None):
        print(f"read dataset config from {config_path}")
        with open(config_path, 'r') as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)
        print("DATASET CONFIG:")
        print(self.config)
        ann = []
        for meta_path in self.config['META']:
            meta_l = json.load(open(meta_path))
            print(f"{meta_path}: len {len(meta_l)}")
            ann += meta_l
        self.ann = ann
        print(f"total length: {len(self)}")
        self.transform = transform
        self.max_words = max_words
        self.tokenizer = Tokenizer(model_path=tokenizer_path)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        data_item = self.ann[index]
        if 'image' in data_item.keys():
            # image-text entry: the first conversation turn is the question, the second the answer
            filename = data_item['image']
            question = data_item['conversations'][0]['value']
            answer = data_item['conversations'][1]['value']

            image = cv2.imread(filename)  # note: cv2 returns a BGR array
            image = Image.fromarray(image)
            image = self.transform(image)
            format_instruction = question
            format_input = None
        else:
            # text-only entry: use a blank image placeholder
            image = torch.zeros(3, 224, 224)
            format_instruction = data_item['instruction']
            format_input = data_item['input']
            answer = data_item['output']
        input1 = llama.utils.format_prompt(format_instruction, format_input)
        input2 = input1 + answer
        input1 = torch.tensor(self.tokenizer.encode(input1, bos=True, eos=False), dtype=torch.int64)
        input2 = torch.tensor(self.tokenizer.encode(input2, bos=True, eos=True), dtype=torch.int64)
        padding = self.max_words - input2.shape[0]
        if padding > 0:
            input2 = torch.cat((input2, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            input2 = input2[:self.max_words]
        labels = copy.deepcopy(input2)
        labels[:len(input1)] = -1  # mask out the prompt tokens; only the answer contributes to the loss
        input2_mask = input2.ge(0)
        label_mask = labels.ge(0)
        input2[~input2_mask] = 0
        labels[~label_mask] = 0
        input2_mask = input2_mask.float()
        label_mask = label_mask.float()
        return input2, labels, input2_mask, image


class PretrainDataset(Dataset):
    def __init__(self, config_path, transform, max_words=30, tokenizer_path=None):
        print(f"read dataset config from {config_path}")
        with open(config_path, 'r') as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)
        print("DATASET CONFIG:")
        print(self.config)
        images, captions = [], []
        for meta_path in self.config['META']:
            images_this_meta, captions_this_meta = [], []
            for chunk in pd.read_csv(meta_path, sep='\t', lineterminator='\n', chunksize=10 ** 6):
                images_this_meta.extend(chunk['url'].tolist())
                captions_this_meta.extend(chunk['caption'].tolist())
            print(f"{meta_path}: len {len(images_this_meta)}")
            images.extend(images_this_meta)
            captions.extend(captions_this_meta)

        self.data_list = []
        for x, y in zip(images, captions):
            self.data_list.append({'url': x, 'caption': y})
        print(f"total length: {len(self)}")
        self.transform = transform
        self.max_words = max_words
        self.tokenizer = Tokenizer(model_path=tokenizer_path)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        sample = self.data_list[index]
        image_path, caption = sample['url'], sample['caption']
        if isinstance(caption, list):
            caption = random.choice(caption)
        caption = str(caption)

        image = cv2.imread(image_path)
        image = Image.fromarray(image)
        image = self.transform(image)

        format_instruction = "Generate caption of this image"
        input1 = llama.utils.format_prompt(format_instruction, None)
        input2 = input1 + caption

        input1 = torch.tensor(self.tokenizer.encode(input1, bos=True, eos=False), dtype=torch.int64)
        input2 = torch.tensor(self.tokenizer.encode(input2, bos=True, eos=True), dtype=torch.int64)
        padding = self.max_words - input2.shape[0]
        if padding > 0:
            input2 = torch.cat((input2, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            input2 = input2[:self.max_words]
        labels = copy.deepcopy(input2)
        labels[:len(input1)] = -1
        input2_mask = input2.ge(0)
        label_mask = labels.ge(0)
        input2[~input2_mask] = 0
        labels[~label_mask] = 0
        input2_mask = input2_mask.float()
        label_mask = label_mask.float()
        return input2, labels, input2_mask, image
```
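For reference, here is a minimal sketch of wiring one of these dataset classes into a `DataLoader`; the config path, tokenizer path, `max_words`, and batch size below are placeholders, not values taken from this commit:

```python
from torch.utils.data import DataLoader

# Placeholder paths and hyper-parameters -- substitute your own.
dataset = FinetuneDataset(
    config_path="configs/finetune.yaml",
    transform=transform_train,
    max_words=512,
    tokenizer_path="/path/to/llama/tokenizer.model",
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# Each batch yields padded token ids, loss labels, a token mask, and image tensors.
tokens, labels, token_mask, images = next(iter(loader))
print(tokens.shape, labels.shape, images.shape)  # (4, 512), (4, 512), (4, 3, 224, 224)
```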

llama_adapter_v2_multimodal/demo.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -8,9 +8,10 @@
 llama_dir = "/path/to/LLaMA/"
 
 model, preprocess = llama.load("BIAS-7B", llama_dir, device)
+model.eval()
 
 prompt = llama.format_prompt('Please introduce this painting.')
-img = Image.fromarray(cv2.imread("../docs/logo_v1.png"))
+img = Image.fromarray(cv2.imread("./docs/logo_v1.png"))
 img = preprocess(img).unsqueeze(0).to(device)
 
 result = model.generate(img, [prompt])[0]
```
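For orientation, a minimal sketch of how the edited lines fit into a complete inference script; the imports and device selection are assumptions (they are not shown in this hunk), and the paths are placeholders:

```python
import cv2
import torch
from PIL import Image

import llama

# Assumed device setup; the hunk above only shows the lines from `llama_dir` onward.
device = "cuda" if torch.cuda.is_available() else "cpu"
llama_dir = "/path/to/LLaMA/"  # placeholder: directory with the LLaMA weights and tokenizer

model, preprocess = llama.load("BIAS-7B", llama_dir, device)
model.eval()  # switch to inference mode before generating

prompt = llama.format_prompt('Please introduce this painting.')
img = Image.fromarray(cv2.imread("./docs/logo_v1.png"))
img = preprocess(img).unsqueeze(0).to(device)

result = model.generate(img, [prompt])[0]
print(result)
```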
Lines changed: 67 additions & 0 deletions
The training process of LLaMA-Adapter V2 consists of the pre-training and fine-tuning phases.

## Pre-training

### Data

* We use multiple datasets with **image-text pairs** for pre-training. The texts are English-only.

* For each dataset, the meta file should be organized as a tab-separated `.csv` file, as follows (see the sketch after this list for one way to generate it):

  ```csv
  url	caption
  /path/to/image1	caption1
  /path/to/image2	caption2
  ...
  ```

  Alternatively, you may modify the [`PretrainDataset`](/data/dataset.py) implementation to adapt it to your own meta file format.

* Write a `.yaml` config file to specify the datasets for pre-training:

  ```yaml
  META:
    - '/path/to/cc3m.csv'
    - '/path/to/coco.csv'
    ...
  ```
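As a convenience, here is a minimal sketch of producing such a meta file with pandas; the image paths and captions are made-up placeholders, and the tab separator matches what `PretrainDataset` expects (`sep='\t'`):

```python
import pandas as pd

# Placeholder records: replace with your own image paths (or URLs) and captions.
records = [
    {"url": "/data/images/000001.jpg", "caption": "a dog running on the beach"},
    {"url": "/data/images/000002.jpg", "caption": "a bowl of fresh fruit on a table"},
]

# PretrainDataset reads the meta file tab-separated, so write it with sep="\t".
pd.DataFrame(records).to_csv("/path/to/my_dataset.csv", sep="\t", index=False)
```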
### Start pre-training

We are now ready to start pre-training (please make sure that the original LLaMA / Open-Chinese-LLaMA weights are available in `/path/to/llama_model_weights`).

```bash
. exps/pretrain.sh /path/to/llama_model_weights /path/to/pretrain-data-config.yaml /output/path
```

## Fine-tuning
### Data

* We fine-tune LLaMA-Adapter V2 on text-only as well as image-text instruction-following datasets. The two meta-file entry formats that `FinetuneDataset` accepts are illustrated in the sketch after this list.

* The following datasets were used to train our released weights:

  | Name | Link |
  | ------------------------ | ------------------------------------------------------------ |
  | alpaca_gpt4_data.json | [File Link](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/data/alpaca_gpt4_data.json) |
  | alpaca_gpt4_data_zh.json | [File Link](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/data/alpaca_gpt4_data_zh.json) |
  | llava_instruct_150k.json | [File Link](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/raw/main/llava_instruct_150k.json) |
  | alpaca_data_zh_51k.json | [File Link](https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/data/alpaca_data_zh_51k.json) |

* Similar to pre-training, write a `.yaml` config file to specify the datasets for fine-tuning:

  ```yaml
  META:
    - '/path/to/alpaca_gpt4_data.json'
    - '/path/to/alpaca_gpt4_data_zh.json'
    ...
  ```
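Based on the `FinetuneDataset` code added in this commit, each meta `.json` file is a list of entries in one of two shapes: image-text entries with an `image` path and a `conversations` question/answer pair, or text-only entries with `instruction`/`input`/`output` fields. A minimal sketch with made-up placeholder values:

```python
import json

# Placeholder entries illustrating the two shapes FinetuneDataset accepts.
meta = [
    {   # image-text entry (LLaVA-style)
        "image": "/path/to/images/000000215677.jpg",
        "conversations": [
            {"from": "human", "value": "What is in the picture?"},
            {"from": "gpt", "value": "A cat sitting on a windowsill."},
        ],
    },
    {   # text-only entry (Alpaca-style)
        "instruction": "Give three tips for staying healthy.",
        "input": "",
        "output": "1. Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep.",
    },
]

with open("/path/to/my_finetune_data.json", "w") as f:
    json.dump(meta, f, ensure_ascii=False, indent=2)
```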
### Start fine-tuning

```bash
. exps/finetune.sh \
    /path/to/llama_model_weights /path/to/pre-trained/checkpoint.pth \
    /path/to/finetune-data-config.yaml /output/path
```
Lines changed: 77 additions & 0 deletions
```python
import math
import sys
from typing import Iterable

import torch

import util.misc as misc
import util.lr_sched as lr_sched

from llama import LLaMA_adapter


def train_one_epoch(model: LLaMA_adapter,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    model.train(True)
    # model.module.set_default_trainability()

    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    for data_iter_step, (examples, labels, example_mask, imgs) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        imgs = imgs.to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            c_loss, m_loss = model(examples, labels, imgs)
        # m_loss is logged but excluded from the optimized objective (multiplied by 0)
        loss = c_loss + m_loss * 0
        loss_value = loss.item()
        c_loss_value = c_loss.item()
        m_loss_value = m_loss
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # scale the loss for gradient accumulation; the scaler only steps on the last micro-batch
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(closs=c_loss_value)
        metric_logger.update(mloss=m_loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
        m_loss_value_reduce = misc.all_reduce_mean(m_loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('c_train_loss', c_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('m_train_loss', m_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
```
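For context, here is a minimal sketch of a driver loop that could call `train_one_epoch`. The `SimpleLossScaler` stand-in and the `args.epochs` field are assumptions (the repo presumably ships its own loss scaler and argument parser); `args.accum_iter` matches what the function above reads, and `train_one_epoch` is assumed to be in scope:

```python
import torch


class SimpleLossScaler:
    """Stand-in loss scaler with the call signature train_one_epoch expects:
    loss_scaler(loss, optimizer, parameters=..., update_grad=...)."""

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, parameters=None, update_grad=True):
        self._scaler.scale(loss).backward()
        if update_grad:
            self._scaler.step(optimizer)
            self._scaler.update()


def run_training(model, data_loader, optimizer, args, device, log_writer=None):
    """Sketch of an outer epoch loop; model, data_loader, optimizer, and args
    (with at least args.accum_iter and args.epochs) are assumed to be built elsewhere."""
    loss_scaler = SimpleLossScaler()
    for epoch in range(args.epochs):
        stats = train_one_epoch(model, data_loader, optimizer, device, epoch,
                                loss_scaler, log_writer=log_writer, args=args)
        print(f"epoch {epoch}: {stats}")
```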
Lines changed: 77 additions & 0 deletions
```python
import math
import sys
from typing import Iterable

import torch

import util.misc as misc
import util.lr_sched as lr_sched

from llama import LLaMA_adapter


def train_one_epoch(model: LLaMA_adapter,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    model.train(True)
    # model.module.set_default_trainability()

    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    for data_iter_step, (examples, labels, example_mask, imgs) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        imgs = imgs.to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            c_loss, m_loss = model(examples, labels, imgs)
        # m_loss is logged but excluded from the optimized objective (multiplied by 0)
        loss = c_loss + m_loss * 0
        loss_value = loss.item()
        c_loss_value = c_loss.item()
        m_loss_value = m_loss
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # scale the loss for gradient accumulation; the scaler only steps on the last micro-batch
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(closs=c_loss_value)
        metric_logger.update(mloss=m_loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
        m_loss_value_reduce = misc.all_reduce_mean(m_loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('c_train_loss', c_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('m_train_loss', m_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
```
