Skip to content

Commit 8fff25b

Browse files
committed
Support MMVP
1 parent 21baf4c commit 8fff25b

File tree

5 files changed

+446
-3
lines changed

5 files changed

+446
-3
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## News🚀🚀🚀
44

5+
- `2024/02/04`: [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) achieves 44.67% on [MMVP](https://github.com/tsb0601/MMVP), higher than GPT-4V!
56
- `2024/01/27`: We release a 448-resolution model that achieves 76.6 on MMBench dev; see [here](https://github.com/OpenGVLab/InternVL/tree/main/internvl_chat#-evaluation-chinese-models).
67
- `2024/01/24`: InternVL-Chat-V1.1 is released; it supports Chinese and has stronger OCR capability. See [here](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) or try our [demo](https://internvl.opengvlab.com/).
78
- `2024/01/16`: We release our [customized mmcv/mmsegmentation/mmdetection code](https://github.com/OpenGVLab/InternVL-MMDetSeg), integrated with DeepSpeed, which can be used for training large-scale object detection and semantic segmentation models.
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import argparse
2+
import csv
3+
import os
4+
5+
import torch
6+
from PIL import Image
7+
from tqdm import tqdm
8+
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor
9+
10+
11+
def benchmark_model(model_name: str, benchmark_dir: str, device: str = 'cuda') -> dict:
    """Score InternVL-14B-224px on the paired questions of an MMVP-VLM folder.

    Rows of ``Questions.csv`` are consumed two at a time. Each statement is
    scored against both images of its pair, and a pair only counts as correct
    when both statements pick their ground-truth image. One line per pair
    (ids, predictions, ground truths, scores) is written to ``output.csv`` in
    the current working directory; that file is overwritten on every call.

    Args:
        model_name: Forwarded to the model's forward call as ``mode``. The
            checkpoint that gets loaded does not depend on it.
        benchmark_dir: Directory containing ``Questions.csv`` and the
            ``MLLM_VLM Images`` folder.
        device: Torch device the model and all input tensors are moved to.

    Returns:
        Dict mapping each category name to the percentage of its evaluated
        pairs that were answered correctly (0.0 for a category without any
        evaluated pair).
    """
    model_path = 'OpenGVLab/InternVL-14B-224px'
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).to(device).eval()
    preprocess = CLIPImageProcessor.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=False, add_eos_token=True)
    tokenizer.pad_token_id = 0  # set pad_token_id to 0
    image_dir = os.path.join(benchmark_dir, 'MLLM_VLM Images')
    csv_file = os.path.join(benchmark_dir, 'Questions.csv')

    categories = [
        'Orientation and Direction', 'Presence of Specific Features',
        'State and Condition', 'Quantity and Count',
        'Positional and Relational Context', 'Color and Appearance',
        'Structural Characteristics', 'Texts',
        'Viewpoint and Perspective'
    ]
    # The category of a pair is derived from its position in the file, which
    # assumes 15 consecutive pairs per category, in the order listed above.
    pairs_per_category = 15

    correct_pairs = {category: 0 for category in categories}
    seen_pairs = {category: 0 for category in categories}
    num_pairs = 0
    prefix = 'summarize:'

    def load_image(qtype, qid):
        # Resize inside the context manager so the file handle is released
        # as soon as the pixel data has been read.
        with Image.open(os.path.join(image_dir, qtype, f'{qid}.jpg')) as raw:
            resized = raw.resize((224, 224))
        pixels = preprocess(images=resized, return_tensors='pt').pixel_values
        return pixels.to(torch.float16).to(device)

    def encode_text(statement):
        encoded = tokenizer(prefix + statement, return_tensors='pt', max_length=80,
                            truncation=True, padding='max_length')
        return encoded.input_ids.to(device)

    def img1_probability(imgs, text):
        # Softmax over the two candidate images for this one statement;
        # entry [0][0] is the probability assigned to the first image.
        with torch.no_grad():
            _, logits_per_text = model(image=imgs, text=text, mode=model_name)
        return logits_per_text.float().softmax(dim=-1).cpu().numpy()[0][0]

    # Both files are managed by ``with`` so they are closed even when a model
    # call or an image load raises part-way through the benchmark.
    with open('output.csv', 'w', newline='') as csv_outfile, open(csv_file, 'r') as f:
        csv_writer = csv.writer(csv_outfile)
        csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score'])  # header

        reader = csv.reader(f)
        next(reader)  # skip header
        for row in tqdm(reader):
            qid1, qtype1, statement1 = row

            # Get next row for the pair; a trailing unpaired row is ignored.
            row = next(reader, None)
            if not row:
                break
            qid2, qtype2, statement2 = row

            qid1, qid2 = int(qid1), int(qid2)

            # Each image is looked up under the type of its own row.
            imgs = torch.cat((load_image(qtype1, qid1), load_image(qtype2, qid2)), dim=0)

            img1_score1 = img1_probability(imgs, encode_text(statement1))
            img1_score2 = img1_probability(imgs, encode_text(statement2))

            pred1 = 'img1' if img1_score1 > 0.5 else 'img2'
            pred2 = 'img1' if img1_score2 > 0.5 else 'img2'

            # Odd question ids are treated as describing the first image.
            gt1 = 'img1' if qid1 % 2 == 1 else 'img2'
            gt2 = 'img1' if qid2 % 2 == 1 else 'img2'

            csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2])

            current_category = categories[num_pairs // pairs_per_category]
            seen_pairs[current_category] += 1
            if pred1 == gt1 and pred2 == gt2:
                correct_pairs[current_category] += 1
            num_pairs += 1

    # Calculate percentage accuracies against the number of pairs actually
    # evaluated in each category, so a short or empty question file cannot
    # trigger a division by zero.
    pair_accuracies = {}
    for category in categories:
        total = seen_pairs[category]
        pair_accuracies[category] = (correct_pairs[category] / total) * 100 if total else 0.0

    return pair_accuracies
105+
106+
107+
# Command-line driver: runs the benchmark once per InternVL mode and prints
# the per-category accuracies, first per model and then as a column table.
parser = argparse.ArgumentParser(description='Process a directory path.')

# The benchmark directory is mandatory: without it the path joins inside
# benchmark_model would fail on None, and only after the model had loaded.
parser.add_argument('--directory', type=str, required=True, help='The path to the directory')

# Parsing the arguments
args = parser.parse_args()

# InternVL models
models = ['InternVL-C', 'InternVL-G']

results = {model: benchmark_model(model, args.directory) for model in models}

print(results)

# Convert results to format suitable for star plot
# The first model's result supplies the category order used for every column.
categories = next(iter(results.values())).keys()
print(f'categories: {categories}')
data = {'Categories': list(categories)}
print(f'data: {data}')
for model, scores in results.items():
    data[model] = [scores[category] for category in categories]
print(f'data: {data}')

internvl_chat/README.md

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ Coming Soon
110110

111111
**MultiModal Benchmark**
112112

113-
| model | MME | MMB<sub>dev/test</sub> | MMB-CN<sub>dev/test</sub> | POPE |
114-
| --------------------------------------------------------------------------------- | -------------- | ---------------------- | ------------------------- | ---- |
115-
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 1672.3 / 341.1 | 76.6 / 75.4 | 71.5 / 70.1 | 87.2 |
113+
| model | MME | MMB<sub>dev/test</sub> | MMB-CN<sub>dev/test</sub> | POPE | MMVP |
114+
| --------------------------------------------------------------------------------- | -------------- | ---------------------- | ------------------------- | ---- | ---- |
115+
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 1672.3 / 341.1 | 76.6 / 75.4 | 71.5 / 70.1 | 87.2 | 44.7 |
116116

117117
| model | MMMU<sub>val/test</sub> | CMMMU<sub>val/test</sub> | Tiny<sub>LVLM</sub> | LLaVA<sub>bench</sub> | MM-Vet |
118118
| --------------------------------------------------------------------------------- | ----------------------- | ------------------------ | ------------------- | --------------------- | ------ |
@@ -284,6 +284,13 @@ data
284284
│ └── Sociology
285285
├── mm-vet
286286
│ └── images/
287+
├── MMVP
288+
│ ├── MMVP Images/
289+
│ ├── Questions.csv
290+
│ └── Questions.xlsx
291+
├── MMVP_VLM
292+
│ ├── MLLM_VLM Images/
293+
│ └── Questions.csv
287294
```
288295

289296
</details>
@@ -974,3 +981,27 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh evaluate.sh <checkpoint> mmvet
974981
```
975982

976983
</details>
984+
985+
#### [MMVP](https://github.com/tsb0601/MMVP)
986+
987+
<details>
988+
<summary>Data Preparation</summary>
989+
990+
```bash
991+
cd data
992+
git lfs install
993+
git clone https://huggingface.co/datasets/MMVP/MMVP
994+
git clone https://huggingface.co/datasets/MMVP/MMVP_VLM
995+
cd ..
996+
```
997+
998+
</details>
999+
1000+
<details>
1001+
<summary>Evaluation</summary>
1002+
1003+
```bash
1004+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh evaluate.sh <checkpoint> mmvp
1005+
```
1006+
1007+
</details>

0 commit comments

Comments
 (0)