Commit 88f1e92

Support dynamic resolution in evaluation
1 parent b1f8bea commit 88f1e92

File tree

19 files changed: +884 additions, -941 deletions

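The change repeated across these evaluation scripts: each dataset constructor gains dynamic_image_size, use_thumbnail and max_num options, __getitem__ tiles the input with dynamic_preprocess before applying the eval transform, and model loading is unified on AutoTokenizer plus InternVLChatModel. A condensed sketch of the preprocessing side of that change, assuming build_transform resizes each view to input_size; the 448 input size and the image path are placeholders, not values from this commit:

    import torch
    from PIL import Image
    from internvl.train.dataset import build_transform, dynamic_preprocess

    transform = build_transform(is_train=False, input_size=448)  # 448 is a placeholder
    image = Image.open('example.jpg').convert('RGB')             # placeholder path

    # Previous behaviour: one resized view -> pixel_values of shape [1, 3, 448, 448].
    pixel_values_single = transform(image).unsqueeze(0)

    # New behaviour with --dynamic: tile first, then stack every tile
    # -> pixel_values of shape [num_tiles, 3, 448, 448].
    tiles = dynamic_preprocess(image, image_size=448, use_thumbnail=True, max_num=6)
    pixel_values_dynamic = torch.stack([transform(tile) for tile in tiles])

    print(pixel_values_single.shape, pixel_values_dynamic.shape)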

internvl_chat/eval/caption/evaluate_caption.py

Lines changed: 33 additions & 23 deletions

@@ -7,12 +7,13 @@
 from functools import partial

 import torch
-from internvl.train.dataset import build_transform
+from internvl.model.internvl_chat import InternVLChatModel
+from internvl.train.dataset import build_transform, dynamic_preprocess
 from PIL import Image
 from pycocoevalcap.eval import COCOEvalCap
 from pycocotools.coco import COCO
 from tqdm import tqdm
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer

 ds_collections = {
     'flickr30k': {
@@ -39,15 +40,20 @@

 class CaptionDataset(torch.utils.data.Dataset):

-    def __init__(self, name, root, annotation, prompt, input_size=224, pad2square=False):
+    def __init__(self, name, root, annotation, prompt, input_size=224, dynamic_image_size=False,
+                 use_thumbnail=False, max_num=6):
         if name == 'coco':
             self.images = json.load(open(annotation))
         else:
             self.images = json.load(open(annotation))['images']
         self.name = name
         self.prompt = prompt
         self.root = root
-        self.transform = build_transform(is_train=False, input_size=input_size, pad2square=pad2square)
+        self.input_size = input_size
+        self.dynamic_image_size = dynamic_image_size
+        self.use_thumbnail = use_thumbnail
+        self.max_num = max_num
+        self.transform = build_transform(is_train=False, input_size=input_size)

     def __len__(self):
         return len(self.images)
@@ -65,7 +71,14 @@ def __getitem__(self, idx):
         image_path = os.path.join(self.root, self.images[idx]['image'])

         image = Image.open(image_path)
-        pixel_values = self.transform(image).unsqueeze(0)
+        if self.dynamic_image_size:
+            images = dynamic_preprocess(image, image_size=self.input_size,
+                                        use_thumbnail=self.use_thumbnail,
+                                        max_num=self.max_num)
+        else:
+            images = [image]
+        pixel_values = [self.transform(image) for image in images]
+        pixel_values = torch.stack(pixel_values)

         return {
             'image_id': image_id,
@@ -125,7 +138,9 @@ def evaluate_chat_model():
             annotation=annotation,
             prompt=prompt,
             input_size=image_size,
-            pad2square=pad2square
+            dynamic_image_size=args.dynamic,
+            use_thumbnail=use_thumbnail,
+            max_num=args.max_num
         )
         dataloader = torch.utils.data.DataLoader(
             dataset=dataset,
@@ -151,7 +166,7 @@ def evaluate_chat_model():
                 tokenizer=tokenizer,
                 pixel_values=pixel_values,
                 question=prompt,
-                generation_config=generation_config,
+                generation_config=generation_config
             )
             image_ids.extend(ids)
             captions.extend([pred])
@@ -217,6 +232,8 @@ def evaluate_chat_model():
     parser.add_argument('--temperature', type=float, default=0.0)
     parser.add_argument('--out-dir', type=str, default='results')
     parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--dynamic', action='store_true')
+    parser.add_argument('--max-num', type=int, default=6)
     args = parser.parse_args()

     if not os.path.exists(args.out_dir):
@@ -234,29 +251,22 @@ def evaluate_chat_model():

     torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))

-    tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
-
-    if 'qllama' in args.checkpoint.lower():
-        from internvl.model.internvl_chat_with_qllama import InternVLChatModel
-        model = InternVLChatModel.from_pretrained(
-            args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()
-        image_size = model.internvl.config.force_image_size or model.config.internvl_config.vision_config.image_size
-        pad2square = model.config.pad2square
-    else:
-        from internvl.model.internvl_chat import InternVLChatModel
-        model = InternVLChatModel.from_pretrained(
-            args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()
-        image_size = model.config.force_image_size or model.config.vision_config.image_size
-        pad2square = model.config.pad2square
+    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
+    model = InternVLChatModel.from_pretrained(
+        args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()
+    image_size = model.config.force_image_size or model.config.vision_config.image_size
+    use_thumbnail = model.config.use_thumbnail

     total_params = sum(p.numel() for p in model.parameters()) / 1e9
-    if total_params > 30:
+    if total_params > 20:
         args.num_beams = 1
         print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
     else:
         print(f'[test] total_params: {total_params}B')
     print(f'[test] image_size: {image_size}')
-    print(f'[test] pad2square: {pad2square}')
     print(f'[test] template: {model.config.template}')
+    print(f'[test] dynamic_image_size: {args.dynamic}')
+    print(f'[test] use_thumbnail: {use_thumbnail}')
+    print(f'[test] max_num: {args.max_num}')

     evaluate_chat_model()
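dynamic_preprocess is imported from internvl.train.dataset above, but its body is not part of this page. A minimal sketch of a tiling helper with the same signature, assuming it picks the tile grid whose aspect ratio best matches the input, caps the grid at max_num tiles, and optionally appends a thumbnail of the full image; the repository's actual implementation may differ:

    from PIL import Image


    def dynamic_preprocess_sketch(image, image_size=448, use_thumbnail=False, max_num=6):
        w, h = image.size
        # Pick the grid (cols, rows) whose aspect ratio is closest to the input,
        # using at most max_num tiles in total.
        candidates = [(i, j) for i in range(1, max_num + 1)
                      for j in range(1, max_num + 1) if i * j <= max_num]
        cols, rows = min(candidates, key=lambda g: abs(g[0] / g[1] - w / h))
        resized = image.resize((image_size * cols, image_size * rows))
        tiles = []
        for r in range(rows):
            for c in range(cols):
                box = (c * image_size, r * image_size,
                       (c + 1) * image_size, (r + 1) * image_size)
                tiles.append(resized.crop(box))
        if use_thumbnail and len(tiles) > 1:
            # Append a global low-resolution view of the whole image.
            tiles.append(image.resize((image_size, image_size)))
        return tiles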

internvl_chat/eval/cmmmu/evaluate_cmmmu.py

Lines changed: 34 additions & 24 deletions

@@ -4,10 +4,11 @@
 import random

 import torch
-from internvl.train.dataset import build_transform
+from internvl.model.internvl_chat import InternVLChatModel
+from internvl.train.dataset import build_transform, dynamic_preprocess
 from PIL import Image
 from tqdm import tqdm
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer

 ds_collections = {
     'art_and_design': {
@@ -51,15 +52,20 @@

 class VQADataset(torch.utils.data.Dataset):

-    def __init__(self, root, annotation, input_size=224, pad2square=False):
+    def __init__(self, root, annotation, input_size=224, dynamic_image_size=False,
+                 use_thumbnail=False, max_num=6):
         self.root = root
         self.items = []
         f = open(annotation)
         data = f.readlines()
         for data_line in data:
             data_line = json.loads(data_line)
             self.items.append(data_line)
-        self.transform = build_transform(is_train=False, input_size=input_size, pad2square=pad2square)
+        self.input_size = input_size
+        self.dynamic_image_size = dynamic_image_size
+        self.use_thumbnail = use_thumbnail
+        self.max_num = max_num
+        self.transform = build_transform(is_train=False, input_size=input_size)

     def __len__(self):
         return len(self.items)
@@ -69,7 +75,15 @@ def __getitem__(self, idx):
         image_path, question = item['image'], item['text']
         image_path = os.path.join(self.root, image_path)
         image = Image.open(image_path).convert('RGB')
-        pixel_values = self.transform(image).unsqueeze(0)
+        if self.dynamic_image_size:
+            images = dynamic_preprocess(image, image_size=self.input_size,
+                                        use_thumbnail=self.use_thumbnail,
+                                        max_num=self.max_num)
+        else:
+            images = [image]
+        pixel_values = [self.transform(image) for image in images]
+        pixel_values = torch.stack(pixel_values)
+
         return {
             'question': question,
             'pixel_values': pixel_values,
@@ -85,7 +99,9 @@ def evaluate_chat_model():
             root=ds_collections[ds_name]['root'],
             annotation=ds_collections[ds_name]['annotation'],
             input_size=image_size,
-            pad2square=pad2square
+            dynamic_image_size=args.dynamic,
+            use_thumbnail=use_thumbnail,
+            max_num=args.max_num
         )

         print(f'Evaluating {ds_name} ...')
@@ -109,9 +125,8 @@ def evaluate_chat_model():
                 tokenizer=tokenizer,
                 pixel_values=pixel_value,
                 question=question,
-                generation_config=generation_config,
+                generation_config=generation_config
             )
-            print(question, pred)
             question_id = item['question_id']
             text = item['text']
             output = {
@@ -137,6 +152,8 @@ def evaluate_chat_model():
     parser.add_argument('--temperature', type=float, default=0.0)
     parser.add_argument('--out-dir', type=str, default='results')
     parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--dynamic', action='store_true')
+    parser.add_argument('--max-num', type=int, default=6)
     args = parser.parse_args()

     if not os.path.exists(args.out_dir):
@@ -146,30 +163,23 @@ def evaluate_chat_model():
     print('datasets:', args.datasets)
     assert args.batch_size == 1, 'Only batch size 1 is supported'

-    tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
-
-    if 'qllama' in args.checkpoint.lower():
-        from internvl.model.internvl_chat_with_qllama import InternVLChatModel
-        model = InternVLChatModel.from_pretrained(
-            args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()
-        image_size = model.internvl.config.force_image_size or model.config.internvl_config.vision_config.image_size
-        pad2square = model.config.pad2square
-    else:
-        from internvl.model.internvl_chat import InternVLChatModel
-        model = InternVLChatModel.from_pretrained(
-            args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()
-        image_size = model.config.force_image_size or model.config.vision_config.image_size
-        pad2square = model.config.pad2square
+    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
+    model = InternVLChatModel.from_pretrained(
+        args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()
+    image_size = model.config.force_image_size or model.config.vision_config.image_size
+    use_thumbnail = model.config.use_thumbnail

     total_params = sum(p.numel() for p in model.parameters()) / 1e9
-    if total_params > 30:
+    if total_params > 20:
         args.num_beams = 1
         print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
     else:
         print(f'[test] total_params: {total_params}B')
     print(f'[test] image_size: {image_size}')
-    print(f'[test] pad2square: {pad2square}')
     print(f'[test] template: {model.config.template}')
+    print(f'[test] dynamic_image_size: {args.dynamic}')
+    print(f'[test] use_thumbnail: {use_thumbnail}')
+    print(f'[test] max_num: {args.max_num}')

     model_id = '_'.join(args.checkpoint.split('/')[-2:])
     evaluate_chat_model()
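A sketch of exercising the updated constructor from within evaluate_cmmmu.py; the paths, the 448 input size, and the printed shape are illustrative placeholders rather than values taken from this commit:

    # Hypothetical call, for illustration only (paths and sizes are placeholders).
    dataset = VQADataset(
        root='data/cmmmu/images',                      # placeholder path
        annotation='data/cmmmu/art_and_design.jsonl',  # placeholder path
        input_size=448,                                # typically model.config.force_image_size
        dynamic_image_size=True,                       # what --dynamic toggles
        use_thumbnail=True,                            # read from model.config.use_thumbnail
        max_num=6,                                     # what --max-num sets
    )
    sample = dataset[0]
    # With dynamic tiling the per-sample tensor carries one row per tile,
    # e.g. torch.Size([7, 3, 448, 448]) for 6 tiles plus a thumbnail;
    # without it the shape stays [1, 3, 448, 448].
    print(sample['pixel_values'].shape)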

internvl_chat/eval/llava_bench/evaluate_llava_bench.py

Lines changed: 33 additions & 25 deletions

@@ -4,10 +4,11 @@
 import random

 import torch
-from internvl.train.dataset import build_transform
+from internvl.model.internvl_chat import InternVLChatModel
+from internvl.train.dataset import build_transform, dynamic_preprocess
 from PIL import Image
 from tqdm import tqdm
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer

 ds_collections = {
     'llava_bench': {
@@ -21,11 +22,16 @@

 class VQADataset(torch.utils.data.Dataset):

-    def __init__(self, root, data, prompt, input_size=224, pad2square=False):
+    def __init__(self, root, data, prompt, input_size=224, dynamic_image_size=False,
+                 use_thumbnail=False, max_num=6):
         self.root = root
         self.data = open(data).readlines()
         self.prompt = prompt
-        self.transform = build_transform(is_train=False, input_size=input_size, pad2square=pad2square)
+        self.input_size = input_size
+        self.dynamic_image_size = dynamic_image_size
+        self.use_thumbnail = use_thumbnail
+        self.max_num = max_num
+        self.transform = build_transform(is_train=False, input_size=input_size)

     def __len__(self):
         return len(self.data)
@@ -37,7 +43,14 @@ def __getitem__(self, idx):

         image = os.path.join(self.root, image)
         image = Image.open(image).convert('RGB')
-        pixel_values = self.transform(image).unsqueeze(0)
+        if self.dynamic_image_size:
+            images = dynamic_preprocess(image, image_size=self.input_size,
+                                        use_thumbnail=self.use_thumbnail,
+                                        max_num=self.max_num)
+        else:
+            images = [image]
+        pixel_values = [self.transform(image) for image in images]
+        pixel_values = torch.stack(pixel_values)
         question = question + self.prompt
         return question_id, question, pixel_values, annotation

@@ -51,7 +64,9 @@ def evaluate_chat_model():
             data=ds_collections[ds_name]['question'],
             prompt=' Please give a detailed answer.',
             input_size=image_size,
-            pad2square=pad2square
+            dynamic_image_size=args.dynamic,
+            use_thumbnail=use_thumbnail,
+            max_num=args.max_num
         )

         outputs = []
@@ -61,16 +76,14 @@ def evaluate_chat_model():
                 num_beams=args.num_beams,
                 max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
                 min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
-                length_penalty=1,
-                # repetition_penalty=1.5,
                 do_sample=True if args.temperature > 0 else False,
                 temperature=args.temperature,
             )
             pred = model.chat(
                 tokenizer=tokenizer,
                 pixel_values=pixel_values,
                 question=question,
-                generation_config=generation_config,
+                generation_config=generation_config
             )
             outputs.append({
                 'question_id': question_id,
@@ -100,6 +113,8 @@ def evaluate_chat_model():
     parser.add_argument('--temperature', type=float, default=0.0)
     parser.add_argument('--out-dir', type=str, default='results')
     parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--dynamic', action='store_true')
+    parser.add_argument('--max-num', type=int, default=6)
     args = parser.parse_args()

     if not os.path.exists(args.out_dir):
@@ -109,30 +124,23 @@ def evaluate_chat_model():
     print('datasets:', args.datasets)
     assert args.batch_size == 1, 'Only batch size 1 is supported'

-    tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
-
-    if 'qllama' in args.checkpoint.lower():
-        from internvl.model.internvl_chat_with_qllama import InternVLChatModel
-        model = InternVLChatModel.from_pretrained(
-            args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()
-        image_size = model.internvl.config.force_image_size or model.config.internvl_config.vision_config.image_size
-        pad2square = model.config.pad2square
-    else:
-        from internvl.model.internvl_chat import InternVLChatModel
-        model = InternVLChatModel.from_pretrained(
-            args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()
-        image_size = model.config.force_image_size or model.config.vision_config.image_size
-        pad2square = model.config.pad2square
+    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
+    model = InternVLChatModel.from_pretrained(
+        args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()
+    image_size = model.config.force_image_size or model.config.vision_config.image_size
+    use_thumbnail = model.config.use_thumbnail

     total_params = sum(p.numel() for p in model.parameters()) / 1e9
-    if total_params > 30:
+    if total_params > 20:
         args.num_beams = 1
         print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
     else:
         print(f'[test] total_params: {total_params}B')
     print(f'[test] image_size: {image_size}')
-    print(f'[test] pad2square: {pad2square}')
     print(f'[test] template: {model.config.template}')
+    print(f'[test] dynamic_image_size: {args.dynamic}')
+    print(f'[test] use_thumbnail: {use_thumbnail}')
+    print(f'[test] max_num: {args.max_num}')

     model_id = '_'.join(args.checkpoint.split('/')[-2:])
     evaluate_chat_model()
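All three scripts now end with the same loading path shown in their final hunks; pulled out below as a standalone sketch, with a placeholder checkpoint path:

    import torch
    from transformers import AutoTokenizer
    from internvl.model.internvl_chat import InternVLChatModel

    checkpoint = 'path/to/internvl_chat_checkpoint'  # placeholder, not from this commit

    # One loader for every checkpoint; the qllama-specific branch is gone.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, use_fast=False)
    model = InternVLChatModel.from_pretrained(
        checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).cuda().eval()

    # Preprocessing settings come from the model config; pad2square is no longer read.
    image_size = model.config.force_image_size or model.config.vision_config.image_size
    use_thumbnail = model.config.use_thumbnail
    print(image_size, use_thumbnail)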
