
Commit 3b1b6e8

Fix retrieval finetune bug & support coco finetune
1 parent 7725e00 commit 3b1b6e8

9 files changed: +427 -161 lines

internvl_g/README.md

Lines changed: 9 additions & 0 deletions

@@ -16,10 +16,13 @@ Coming Soon
 
 Three datasets need to be prepared: COCO Caption, Flickr30K, and NoCaps.
 
+You can download the `coco_karpathy_train.json` from [here](https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json).
+
 ```shell
 data
 ├── coco
 │   ├── annotations
+│   │   ├── coco_karpathy_train.json
 │   ├── test2017
 │   ├── train2014
 │   ├── train2017

@@ -78,6 +81,12 @@ To fine-tune InternVL on Flickr30K-CN with 32 GPUs, run:
 sh shell/finetune/internvl_stage2_finetune_flickrcn_364_bs1024_ep10.sh
 ```
 
+To fine-tune InternVL on COCO with 32 GPUs, run:
+
+```shell
+sh shell/finetune/internvl_stage2_finetune_coco_364_bs1024_ep5.sh
+```
+
 ## 📊 Evaluation
 
 ### Zero-Shot Image Captioning
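The README addition above asks you to fetch `coco_karpathy_train.json` yourself; for reference, a minimal Python sketch of that step (not part of this commit — the relative `data/coco/annotations` destination is assumed from the tree shown above):

```python
# Sketch only: download coco_karpathy_train.json into the layout from the README tree.
import os
import urllib.request

URL = ('https://storage.googleapis.com/sfr-vision-language-research/'
       'datasets/coco_karpathy_train.json')
DST_DIR = 'data/coco/annotations'  # assumed to be resolved from the internvl_g directory

os.makedirs(DST_DIR, exist_ok=True)
urllib.request.urlretrieve(URL, os.path.join(DST_DIR, 'coco_karpathy_train.json'))
```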

internvl_g/internvl/model/internvl_stage2/modeling_internvl.py

Lines changed: 2 additions & 2 deletions

@@ -375,7 +375,7 @@ def forward(
             image_attention_mask, attention_mask, input_embeds, repeat_time
         )
         if type(self.qllama.model) == LlamaForCausalLM:
-            outputs = self.qllama.model.model.custom_forward(
+            outputs = self.qllama.model.model.forward_train(
                 inputs_embeds=input_embeds,
                 vision_hidden_states=image_embeds,
                 attention_mask=attention_mask,
@@ -385,7 +385,7 @@ def forward(
                 repeat_time=repeat_time,
             ).last_hidden_state
         else:
-            outputs = self.qllama.model.custom_forward(
+            outputs = self.qllama.model.forward_train(
                 inputs_embeds=input_embeds,
                 vision_hidden_states=image_embeds,
                 attention_mask=attention_mask,

internvl_g/internvl/model/internvl_stage2_retrieval/modeling_internvl.py

Lines changed: 2 additions & 2 deletions

@@ -367,7 +367,7 @@ def forward(
             image_attention_mask, attention_mask, input_embeds, repeat_time
         )
         if type(self.qllama.model) == LlamaForCausalLM:
-            outputs = self.qllama.model.model.custom_forward(
+            outputs = self.qllama.model.model.forward_train(
                 inputs_embeds=input_embeds,
                 vision_hidden_states=image_embeds,
                 attention_mask=attention_mask,
@@ -377,7 +377,7 @@ def forward(
                 repeat_time=repeat_time,
            ).last_hidden_state
        else:
-            outputs = self.qllama.model.custom_forward(
+            outputs = self.qllama.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
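Both files receive the same call-site fix: the Q-LLaMA branch previously called `custom_forward`, and the rename to `forward_train` appears to be the retrieval fine-tune bug mentioned in the commit message. A minimal sketch of the dispatch pattern around that call, using stand-in classes (only the `.model` nesting and the `forward_train` name come from the diff; everything else here is hypothetical):

```python
# Stand-ins that mimic the attribute layout exercised by the diff above.
from types import SimpleNamespace


class StubLlamaModel:
    """Plays the role of the base LlamaModel exposing forward_train."""

    def forward_train(self, **kwargs):
        # The real method runs the LM with vision_hidden_states injected; here we
        # only return an object that has a .last_hidden_state attribute.
        return SimpleNamespace(last_hidden_state=kwargs)


class StubLlamaForCausalLM:
    """Plays the role of the causal-LM wrapper, which keeps the base model at .model."""

    def __init__(self):
        self.model = StubLlamaModel()


def run_qllama(qllama_model, **kwargs):
    # Mirrors the diff: unwrap one extra .model level when handed the causal-LM
    # wrapper, then call forward_train (formerly custom_forward) on the base model.
    if isinstance(qllama_model, StubLlamaForCausalLM):
        return qllama_model.model.forward_train(**kwargs).last_hidden_state
    return qllama_model.forward_train(**kwargs).last_hidden_state
```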
Lines changed: 283 additions & 0 deletions

@@ -0,0 +1,283 @@

```python
import json
import random
import re
from typing import Dict

import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import InterpolationMode


def build_transform(input_size):
    # match fine-tune setting with blip2
    # https://github.com/salesforce/LAVIS/blob/main/lavis/processors/blip_processors.py
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.RandomResizedCrop(input_size, scale=(0.5, 1.0),
                            interpolation=InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    return transform


class FlickrDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, metas, tokenizer, data_args):
        super(FlickrDataset, self).__init__()

        f = open(metas['annotation'])
        lines = f.readlines()[1:]

        self.data_args = data_args
        self.tokenizer = tokenizer
        self.images = []
        self.image_ids = []
        self.captions = []

        for line in lines:
            image, caption = line.strip().split('.jpg,')
            image_id = int(image)
            caption = self.process_single_caption(caption)
            image = image + '.jpg'
            image_path = metas['root'] + '/' + image
            self.images.append(image_path)
            self.image_ids.append(image_id)
            self.captions.append(caption)
        print(f'There are {len(self.images)} images.')
        print(f'There are {len(self.captions)} captions.')

    def __len__(self):
        return len(self.images)

    def process_single_caption(self, caption, max_words=50):
        caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
        caption = re.sub(r'\s{2,}', ' ', caption)
        caption = caption.rstrip('\n')
        caption = caption.strip(' ')

        # truncate caption
        caption_words = caption.split(' ')
        if len(caption_words) > max_words:
            caption = ' '.join(caption_words[: max_words])
        return caption

    def preprocess(self, image, caption, neg_caption):
        model_inputs = dict()

        # input image
        image_transform = build_transform(input_size=self.data_args.force_image_size)
        image = Image.open(image)
        image = image.convert('RGB')
        pixel_values = image_transform(image)
        model_inputs['pixel_values'] = pixel_values

        # for image-text matching
        pos_model_inputs = self.tokenizer(
            caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['positive_input_ids'] = pos_model_inputs['input_ids']
        model_inputs['positive_attention_mask'] = pos_model_inputs['attention_mask']
        neg_model_inputs = self.tokenizer(
            neg_caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['negative_input_ids'] = neg_model_inputs['input_ids']
        model_inputs['negative_attention_mask'] = neg_model_inputs['attention_mask']

        # for image-text contrastive learning
        summarize_model_inputs = self.tokenizer(
            'summarize:' + caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['summarize_input_ids'] = summarize_model_inputs['input_ids']
        model_inputs['summarize_attention_mask'] = summarize_model_inputs['attention_mask']

        # for image-grounded text generation
        prefix = f'English caption:'
        content = caption
        tokenized_prefix = self.tokenizer(
            prefix, padding=False, truncation=True, return_tensors='pt',
        )
        prefix_input_ids = tokenized_prefix['input_ids'][:, :-1]  # remove eos
        prefix_attention_mask = tokenized_prefix['attention_mask'][:, :-1]  # remove eos
        tokenized_content = self.tokenizer(
            content,
            max_length=self.data_args.max_seq_length - prefix_input_ids.size(1) + 1,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        content_input_ids = tokenized_content['input_ids'][:, 1:]  # remove bos
        content_attention_mask = tokenized_content['attention_mask'][:, 1:]  # remove bos
        model_inputs['input_ids'] = torch.cat([prefix_input_ids, content_input_ids], dim=1)
        model_inputs['attention_mask'] = torch.cat([prefix_attention_mask, content_attention_mask], dim=1)
        labels = model_inputs['input_ids'].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        labels[:, :prefix_input_ids.size(1) - 1] = -100
        model_inputs['labels'] = labels
        return model_inputs

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        i = i % len(self.images)
        j = random.randint(0, len(self.images) - 1)
        while self.image_ids[j] == self.image_ids[i]:
            j = random.randint(0, len(self.images) - 1)
        ret = self.preprocess(self.images[i], self.captions[i], self.captions[j])
        # for image-text matching
        ret['positive_input_ids'] = ret['positive_input_ids'][0]
        ret['positive_attention_mask'] = ret['positive_attention_mask'][0]
        ret['negative_input_ids'] = ret['negative_input_ids'][0]
        ret['negative_attention_mask'] = ret['negative_attention_mask'][0]
        # for image-text contrastive learning
        ret['summarize_input_ids'] = ret['summarize_input_ids'][0]
        ret['summarize_attention_mask'] = ret['summarize_attention_mask'][0]
        # for image-grounded text generation
        ret['input_ids'] = ret['input_ids'][0]
        ret['attention_mask'] = ret['attention_mask'][0]
        ret['labels'] = ret['labels'][0]
        ret['image_ids'] = torch.Tensor([self.image_ids[i]]).long()
        return ret


class COCODataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, metas, tokenizer, data_args):
        super(COCODataset, self).__init__()

        annotations = json.load(open(metas['annotation']))

        self.data_args = data_args
        self.tokenizer = tokenizer
        self.images = []
        self.image_ids = []
        self.captions = []

        for annotation in annotations:
            image_id = int(annotation['image_id'].split('_')[-1])
            caption = annotation['caption']
            caption = self.process_single_caption(caption)
            image = annotation['image']
            image_path = metas['root'] + '/' + image
            self.images.append(image_path)
            self.image_ids.append(image_id)
            self.captions.append(caption)
        print(f'There are {len(self.images)} images.')
        print(f'There are {len(self.captions)} captions.')

    def __len__(self):
        return len(self.images)

    def process_single_caption(self, caption, max_words=50):
        caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
        caption = re.sub(r'\s{2,}', ' ', caption)
        caption = caption.rstrip('\n')
        caption = caption.strip(' ')

        # truncate caption
        caption_words = caption.split(' ')
        if len(caption_words) > max_words:
            caption = ' '.join(caption_words[: max_words])
        return caption

    def preprocess(self, image, caption, neg_caption):
        model_inputs = dict()

        # input image
        image_transform = build_transform(input_size=self.data_args.force_image_size)
        image = Image.open(image)
        image = image.convert('RGB')
        pixel_values = image_transform(image)
        model_inputs['pixel_values'] = pixel_values

        # for image-text matching
        pos_model_inputs = self.tokenizer(
            caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['positive_input_ids'] = pos_model_inputs['input_ids']
        model_inputs['positive_attention_mask'] = pos_model_inputs['attention_mask']
        neg_model_inputs = self.tokenizer(
            neg_caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['negative_input_ids'] = neg_model_inputs['input_ids']
        model_inputs['negative_attention_mask'] = neg_model_inputs['attention_mask']

        # for image-text contrastive learning
        summarize_model_inputs = self.tokenizer(
            'summarize:' + caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['summarize_input_ids'] = summarize_model_inputs['input_ids']
        model_inputs['summarize_attention_mask'] = summarize_model_inputs['attention_mask']

        # for image-grounded text generation
        prefix = f'English caption:'
        content = caption
        tokenized_prefix = self.tokenizer(
            prefix, padding=False, truncation=True, return_tensors='pt',
        )
        prefix_input_ids = tokenized_prefix['input_ids'][:, :-1]  # remove eos
        prefix_attention_mask = tokenized_prefix['attention_mask'][:, :-1]  # remove eos
        tokenized_content = self.tokenizer(
            content,
            max_length=self.data_args.max_seq_length - prefix_input_ids.size(1) + 1,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        content_input_ids = tokenized_content['input_ids'][:, 1:]  # remove bos
        content_attention_mask = tokenized_content['attention_mask'][:, 1:]  # remove bos
        model_inputs['input_ids'] = torch.cat([prefix_input_ids, content_input_ids], dim=1)
        model_inputs['attention_mask'] = torch.cat([prefix_attention_mask, content_attention_mask], dim=1)
        labels = model_inputs['input_ids'].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        labels[:, :prefix_input_ids.size(1) - 1] = -100
        model_inputs['labels'] = labels
        return model_inputs

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        i = i % len(self.images)
        j = random.randint(0, len(self.images) - 1)
        while self.image_ids[j] == self.image_ids[i]:
            j = random.randint(0, len(self.images) - 1)
        ret = self.preprocess(self.images[i], self.captions[i], self.captions[j])
        # for image-text matching
        ret['positive_input_ids'] = ret['positive_input_ids'][0]
        ret['positive_attention_mask'] = ret['positive_attention_mask'][0]
        ret['negative_input_ids'] = ret['negative_input_ids'][0]
        ret['negative_attention_mask'] = ret['negative_attention_mask'][0]
        # for image-text contrastive learning
        ret['summarize_input_ids'] = ret['summarize_input_ids'][0]
        ret['summarize_attention_mask'] = ret['summarize_attention_mask'][0]
        # for image-grounded text generation
        ret['input_ids'] = ret['input_ids'][0]
        ret['attention_mask'] = ret['attention_mask'][0]
        ret['labels'] = ret['labels'][0]
        ret['image_ids'] = torch.Tensor([self.image_ids[i]]).long()
        return ret
```
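Reading the loaders above: `FlickrDataset` expects a CSV-style annotation file whose header line is skipped and whose remaining lines look like `<numeric_image_id>.jpg,<caption>`, while `COCODataset` expects the Karpathy-split JSON, i.e. a list of records carrying `image` (a path joined onto `metas['root']`), `caption`, and an `image_id` whose last underscore-separated token is numeric; the negative caption for image-text matching is drawn at random from a different `image_id`. A hypothetical construction sketch follows — the `metas` paths, the `DataArgs` values, and the `bert-base-uncased` placeholder tokenizer are assumptions, not the repository's actual training entry point:

```python
# Hypothetical wiring around the COCODataset class defined in the new file above;
# the real finetune script builds its own tokenizer and data_args.
from dataclasses import dataclass

from torch.utils.data import DataLoader
from transformers import AutoTokenizer


@dataclass
class DataArgs:
    force_image_size: int = 364      # 364 matches the finetune script names above
    max_seq_length: int = 80         # assumed value
    pad_to_max_length: bool = True   # fixed-length padding so default collation works


metas = {
    'annotation': 'data/coco/annotations/coco_karpathy_train.json',
    'root': 'data/coco',             # joined with each record's 'image' field
}
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # placeholder with a pad token
dataset = COCODataset(metas, tokenizer, DataArgs())

sample = dataset[0]
print(sample['pixel_values'].shape)  # torch.Size([3, 364, 364])
print(sample['input_ids'].shape, sample['labels'].shape)

loader = DataLoader(dataset, batch_size=4, shuffle=True)
```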
