Commit c8b347f

support MME eval & naive quant for VLM

1 parent aecd6c4 commit c8b347f

8 files changed: +358 -10 lines changed

llmc/__main__.py

Lines changed: 17 additions & 1 deletion

@@ -15,7 +15,8 @@
 from llmc.compression.quantization import *
 from llmc.compression.sparsification import *
 from llmc.data import BaseDataset, BaseTokenizer
-from llmc.eval import AccuracyEval, PerplexityEval, TokenConsistencyEval
+from llmc.eval import (AccuracyEval, PerplexityEval, TokenConsistencyEval,
+                       VLMEval)
 from llmc.models import *
 from llmc.utils import (check_config, mkdirs, print_important_package_version,
                         seed_all, update_autoawq_quant_config,
@@ -48,6 +49,9 @@ def main(config):
         if config.eval.type == 'acc':
             acc_eval = AccuracyEval(eval_config)
             eval_list.append(acc_eval)
+        elif config.eval.type == 'img_txt':
+            acc_eval = VLMEval(eval_config)
+            eval_list.append(acc_eval)
         else:
             ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), eval_config)
             eval_list.append(ppl_eval)
@@ -57,6 +61,10 @@ def main(config):
             for acc_eval in eval_list:
                 acc = acc_eval.eval(model)
                 logger.info(f'{config.eval.name} acc : {acc}')
+        elif config.eval.type == 'img_txt':
+            for vlm_eval in eval_list:
+                results = vlm_eval.eval(model, tokenizer)
+                logger.info(f'{config.eval.name} results : {results}')
         else:
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
@@ -106,6 +114,10 @@ def main(config):
             for acc_eval in eval_list:
                 acc = acc_eval.eval(model)
                 logger.info(f'{config.eval.name} acc : {acc}')
+        elif config.eval.type == 'img_txt':
+            for vlm_eval in eval_list:
+                results = vlm_eval.eval(model, tokenizer)
+                logger.info(f'{config.eval.name} results : {results}')
         else:
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
@@ -130,6 +142,10 @@ def main(config):
             for acc_eval in eval_list:
                 acc = acc_eval.eval(model)
                 logger.info(f'{config.eval.name} acc : {acc}')
+        elif config.eval.type == 'img_txt':
+            for vlm_eval in eval_list:
+                results = vlm_eval.eval(model, tokenizer)
+                logger.info(f'{config.eval.name} results : {results}')
         else:
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
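
Note: the new img_txt branch is driven entirely by the eval section of the config. The type field selects the branch in main(), while VLMEval reads name, path, and bs from the same section. A minimal sketch of that mapping as the code would see it (the dataset path is an assumed placeholder, not part of this commit):

# Hypothetical eval config for the new VLM path; the keys mirror what
# main() and VLMEval.__init__ read in this commit (type, name, path, bs).
eval_config = {
    'type': 'img_txt',       # routes main() into the VLMEval branch
    'name': 'MME',           # the only dataset VLMEval accepts for now
    'path': '/path/to/MME',  # assumed dataset root holding img_qa.json
    'bs': 8,                 # batch size handed to model.batch_process
}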

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 19 additions & 6 deletions

@@ -271,6 +271,9 @@ def set_quant_config(self):
         self.intermediate_size = self.model.model_config.intermediate_size
         self.fp32_had = special_config.get('fp32_had', False)
 
+        self.quant_objects = self.quant_config.get('quant_objects', ['language'])
+        logger.info(f'self.quant_objects : {self.quant_objects}')
+
     def replace_rotate_linears(self, block):
         for n, m in block.named_modules():
             if isinstance(m, nn.Linear) and ('down_proj' in n
@@ -806,12 +809,22 @@ def deploy(self, quant_format, keep_device=False):
         )
 
         module = module_mapping[quant_format]
-        self.model.replace_module_all(
-            module,
-            self.get_replacement_params(mode=quant_format, w_only=self.w_only),
-            keep_device=keep_device
-        )
-        self.set_non_linear_mode(quant_format, self.model.model, False)
+        if 'vision' in self.quant_objects:
+            self.model.replace_vision_module_all(
+                module,
+                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
+                keep_device=keep_device
+            )
+        if 'language' in self.quant_objects:
+            self.model.replace_language_module_all(
+                module,
+                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
+                keep_device=keep_device
+            )
+            self.set_non_linear_mode(quant_format, self.model.model, False)
+
+        if hasattr(self.model, 'vlm_model'):
+            logger.info(f'Now, the vlm_model is: {self.model.vlm_model}')
 
         logger.info(f'-- deploy_{quant_format}_model done --')
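
Note: the change above hinges on the new quant_objects switch. set_quant_config() reads it from the quant config with a default of ['language'], and deploy() then replaces the vision and/or language modules accordingly. A rough sketch of how a quant config might enable both towers; only quant_objects comes from this commit, every other key is an illustrative placeholder:

# Hypothetical quant config for naive quantization of both parts of a VLM.
quant_config = {
    'method': 'RTN',                          # assumed naive round-to-nearest method
    'weight': {'bit': 8, 'symmetric': True},  # illustrative weight settings
    'quant_objects': ['vision', 'language'],  # defaults to ['language'] when omitted
}
# deploy() then dispatches:
#   'vision'   in quant_objects -> model.replace_vision_module_all(...)
#   'language' in quant_objects -> model.replace_language_module_all(...)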

llmc/compression/quantization/llmint8.py

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ def deploy(self, quant_format):
         logger.info(f'-- deploy_{quant_format}_model start --')
         logger.info(f'quant_config : {self.quant_config}')
 
-        self.model.replace_module_all(
+        self.model.replace_language_module_all(
             FakeQuantLinear,
             self.get_replacement_params(
                 mode='fake_quant', w_only=self.w_only, name=None

llmc/eval/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 from .eval_acc import AccuracyEval
 from .eval_ppl import PerplexityEval
 from .eval_token_consist import TokenConsistencyEval
+from .eval_vlm import VLMEval

llmc/eval/eval_vlm.py

Lines changed: 254 additions & 0 deletions

@@ -0,0 +1,254 @@
+import gc
+import json
+import os
+from collections import defaultdict
+
+import torch
+from loguru import logger
+from sklearn.metrics import (accuracy_score, confusion_matrix, precision_score,
+                             recall_score)
+
+
+class VLMEval:
+    def __init__(self, eval_config):
+        self.eval_config = eval_config
+        self.dataset = eval_config['name']
+        assert self.dataset in [
+            'MME',
+        ], 'VLM eval only support MME dataset now.'
+        self.eval_dataset_path = eval_config['path']
+        self.eval_bs = eval_config['bs']
+        if self.dataset == 'MME':
+            self.img_qas = self.load_mme()
+        logger.info('VLMEval load dataset done.')
+
+    def load_mme(self):
+        img_qa_json = os.path.join(self.eval_dataset_path, 'img_qa.json')
+        fp = open(img_qa_json)
+        img_qas = json.load(fp)
+        for idx in range(len(img_qas)):
+            img_qas[idx]['img'] = os.path.join(
+                self.eval_dataset_path, img_qas[idx]['img']
+            )
+        return img_qas
+
+    def eval(self, model, tokenizer):
+        vlm_model = model.vlm_model
+        vlm_tokenizer = tokenizer.get_tokenizer()
+        vlm_model.cuda()
+        results = []
+        logger.info(f'len(self.img_qas): {len(self.img_qas)}')
+        logger.info(f'eval_bs: {self.eval_bs}')
+        for idx in range(0, len(self.img_qas), self.eval_bs):
+            logger.info(
+                f'index : {(idx + 1) // self.eval_bs}/{len(self.img_qas) // self.eval_bs}'
+            )
+            start = idx
+            end = min(idx + self.eval_bs, len(self.img_qas))
+            batch_samples = self.img_qas[start:end]
+            inputs = model.batch_process(batch_samples)
+            inputs = {
+                k: (
+                    v.to(next(vlm_model.parameters()).device)
+                    if torch.is_tensor(v)
+                    else v
+                )
+                for k, v in inputs.items()
+            }
+            outputs = vlm_model.generate(**inputs, max_new_tokens=32, do_sample=False)
+            gen_txts = vlm_tokenizer.batch_decode(
+                outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True
+            )
+            for n in range(len(batch_samples)):
+                result = batch_samples[n].copy()
+                result.update({'gen_txt': gen_txts[n]})
+                results.append(result)
+        if self.dataset == 'MME':
+            eval_class = MME()
+            vlm_score = eval_class(results)
+
+        vlm_model.cpu()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        return vlm_score
+
+
+class MME:
+    def __init__(self):
+        self.eval_type_dict = {
+            'Perception': [
+                'existence',
+                'count',
+                'position',
+                'color',
+                'posters',
+                'celebrity',
+                'scene',
+                'landmark',
+                'artwork',
+                'OCR',
+            ],
+            'Cognition': [
+                'commonsense_reasoning',
+                'numerical_calculation',
+                'text_translation',
+                'code_reasoning',
+            ],
+        }
+
+    def divide_chunks(self, lines, n=2):
+        # looping till length lines
+        for i in range(0, len(lines), n):
+            yield lines[i: i + n]
+
+        return
+
+    def parse_pred_ans(self, pred_ans):
+        pred_label = None
+        if pred_ans in ['yes', 'no']:
+            pred_label = pred_ans
+        else:
+            prefix_pred_ans = pred_ans[:4]
+
+            if 'yes' in prefix_pred_ans:
+                pred_label = 'yes'
+            elif 'no' in prefix_pred_ans:
+                pred_label = 'no'
+            else:
+                pred_label = 'other'
+
+        return pred_label
+
+    def compute_metric(self, gts, preds):
+        assert len(gts) == len(preds)
+
+        label_map = {
+            'yes': 1,
+            'no': 0,
+            'other': -1,
+        }
+
+        gts = [label_map[x] for x in gts]
+        preds = [label_map[x] for x in preds]
+
+        acc = accuracy_score(gts, preds)
+
+        clean_gts = []
+        clean_preds = []
+        other_num = 0
+        for gt, pred in zip(gts, preds):
+            if pred == -1:
+                other_num += 1
+                continue
+            clean_gts.append(gt)
+            clean_preds.append(pred)
+
+        conf_mat = confusion_matrix(clean_gts, clean_preds, labels=[1, 0])
+        precision = precision_score(clean_gts, clean_preds, average='binary')
+        recall = recall_score(clean_gts, clean_preds, average='binary')
+        tp, fn = conf_mat[0]
+        fp, tn = conf_mat[1]
+
+        metric_dict = dict()
+        metric_dict = {
+            'TP': tp,
+            'FN': fn,
+            'TN': tn,
+            'FP': fp,
+            'precision': precision,
+            'recall': recall,
+            'other_num': other_num,
+            'acc': acc,
+        }
+
+        return metric_dict
+
+    def get_lines(self, results):
+        lines_dict = defaultdict(list)
+        for res in results:
+            task_name = res['img'].split('/')[-2]
+            assert (
+                task_name in self.eval_type_dict['Perception']
+                or task_name in self.eval_type_dict['Cognition']
+            )
+            txt = (
+                res['img'].split('/')[-1]
+                + '\t'
+                + res['question']
+                + '\t'
+                + res['answer']
+                + '\t'
+                + res['gen_txt']
+                + '\n'
+            )
+            lines_dict[task_name].append(txt)
+        return lines_dict
+
+    def __call__(self, results):
+        lines_dict = self.get_lines(results)
+        mme_scores = {}
+        for eval_type, task_name_list in self.eval_type_dict.items():
+            mme_scores[eval_type] = {}
+
+            scores = 0
+            task_score_dict = dict()
+
+            for task_name in task_name_list:
+                lines = lines_dict[task_name]
+                chunk_lines = list(
+                    self.divide_chunks(lines)
+                )  # one image corresponds to two questions
+
+                img_num = len(chunk_lines)
+                task_other_ans_num = 0
+                task_score = 0
+                acc_plus_correct_num = 0
+                gts = []
+                preds = []
+
+                for img_items in chunk_lines:
+                    assert len(img_items) == 2
+                    img_correct_num = 0
+
+                    for img_item in img_items:
+                        img_name, question, gt_ans, pred_ans = img_item.split('\t')
+
+                        gt_ans = gt_ans.lower()
+                        pred_ans = pred_ans.lower()
+
+                        assert gt_ans in ['yes', 'no']  # gt can only be yes or no.
+
+                        pred_ans = self.parse_pred_ans(pred_ans)
+                        assert pred_ans in ['yes', 'no', 'other']
+
+                        gts.append(gt_ans)
+                        preds.append(pred_ans)
+
+                        if gt_ans == pred_ans:
+                            img_correct_num += 1
+
+                        if pred_ans not in ['yes', 'no']:
+                            task_other_ans_num += 1
+
+                    if img_correct_num == 2:
+                        acc_plus_correct_num += 1
+
+                # cal TP precision acc, etc.
+                metric_dict = self.compute_metric(gts, preds)
+                acc_plus = acc_plus_correct_num / img_num
+                metric_dict['acc_plus'] = acc_plus
+
+                for k, v in metric_dict.items():
+                    if k in ['acc', 'acc_plus']:
+                        task_score += v * 100
+
+                task_score_dict[task_name] = task_score
+
+                scores += task_score
+
+            mme_scores[eval_type]['total_score'] = scores
+            for task_name, score in task_score_dict.items():
+                mme_scores[eval_type][task_name] = score
+
+        return json.dumps(mme_scores, ensure_ascii=False, indent=4)
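
Usage note: load_mme() expects an img_qa.json file at the dataset root, and get_lines() later derives the MME task name from the parent directory of each 'img' path, so a plausible layout of that file looks like the sketch below. The file names and questions are made up for illustration; the real MME annotations are not part of this commit.

# Assumed img_qa.json layout consumed by VLMEval.load_mme(): each entry has a
# relative 'img' path (its parent directory must be an MME task name, e.g.
# 'existence'), a 'question', and a yes/no 'answer'; MME asks two questions
# per image, which divide_chunks() later regroups into pairs.
import json

img_qas = [
    {'img': 'existence/0001.png',
     'question': 'Is there a dog in the image? Please answer yes or no.',
     'answer': 'Yes'},
    {'img': 'existence/0001.png',
     'question': 'Is there a cat in the image? Please answer yes or no.',
     'answer': 'No'},
]
with open('img_qa.json', 'w') as fp:
    json.dump(img_qas, fp, ensure_ascii=False, indent=4)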
