Commit cf0be26

support MME eval & naive quant for VLM (#189)
1 parent: ebee43b

8 files changed: +369 −83 lines

llmc/__main__.py (17 additions, 13 deletions)
```diff
@@ -15,7 +15,8 @@
 from llmc.compression.quantization import *
 from llmc.compression.sparsification import *
 from llmc.data import BaseDataset, BaseTokenizer
-from llmc.eval import AccuracyEval, PerplexityEval, TokenConsistencyEval
+from llmc.eval import (AccuracyEval, PerplexityEval, TokenConsistencyEval,
+                       VLMEval)
 from llmc.models import *
 from llmc.utils import (check_config, mkdirs, print_important_package_version,
                         seed_all, update_autoawq_quant_config,
@@ -48,6 +49,9 @@ def main(config):
     if config.eval.type == 'acc':
         acc_eval = AccuracyEval(eval_config)
         eval_list.append(acc_eval)
+    elif config.eval.type == 'img_txt':
+        acc_eval = VLMEval(eval_config)
+        eval_list.append(acc_eval)
     else:
         ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), eval_config)
         eval_list.append(ppl_eval)
@@ -57,6 +61,10 @@ def main(config):
         for acc_eval in eval_list:
             acc = acc_eval.eval(model)
             logger.info(f'{config.eval.name} acc : {acc}')
+    elif config.eval.type == 'img_txt':
+        for vlm_eval in eval_list:
+            results = vlm_eval.eval(model, tokenizer)
+            logger.info(f'{config.eval.name} results : {results}')
     else:
         for ppl_eval in eval_list:
             ppl = ppl_eval.eval(model)
@@ -76,18 +84,6 @@ def main(config):
     dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
     calib_data, padding_mask = dataset.get_calib_dataset()
     padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
-    if config.calib.type == 'img_txt':
-        model.collect_first_encoder_block_input(calib_data, padding_mask,
-                                                padding_side, config.calib.type)
-        blockwise_opt = ALGO_REGISTRY[config.quant.method](
-            model,
-            config.quant,
-            model.get_first_block_input(),
-            model.get_padding_mask(),
-            config,
-            'vision'
-        )
-        blockwise_opt.run_block_loop()
     model.collect_first_block_input(calib_data, padding_mask, padding_side, config.calib.type)
     del calib_data
     gc.collect()
@@ -118,6 +114,10 @@ def main(config):
         for acc_eval in eval_list:
             acc = acc_eval.eval(model)
             logger.info(f'{config.eval.name} acc : {acc}')
+    elif config.eval.type == 'img_txt':
+        for vlm_eval in eval_list:
+            results = vlm_eval.eval(model, tokenizer)
+            logger.info(f'{config.eval.name} results : {results}')
     else:
         for ppl_eval in eval_list:
             ppl = ppl_eval.eval(model)
@@ -142,6 +142,10 @@ def main(config):
         for acc_eval in eval_list:
             acc = acc_eval.eval(model)
             logger.info(f'{config.eval.name} acc : {acc}')
+    elif config.eval.type == 'img_txt':
+        for vlm_eval in eval_list:
+            results = vlm_eval.eval(model, tokenizer)
+            logger.info(f'{config.eval.name} results : {results}')
     else:
         for ppl_eval in eval_list:
             ppl = ppl_eval.eval(model)
```
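For orientation, here is a minimal, self-contained sketch of the eval routing this diff introduces. The stub classes and the config shape are illustrative stand-ins, not the real llmc implementations; only the dispatch on `config.eval.type == 'img_txt'` and the `eval(model, tokenizer)` call signature come from the diff itself.

```python
# Illustrative stubs; the real classes live in llmc.eval.
from types import SimpleNamespace

class AccuracyEval:
    def __init__(self, eval_config): self.eval_config = eval_config

class PerplexityEval:
    def __init__(self, tokenizer, eval_config): self.eval_config = eval_config

class VLMEval:
    def __init__(self, eval_config): self.eval_config = eval_config
    def eval(self, model, tokenizer):
        # The real VLMEval runs an image-text benchmark such as MME and
        # returns per-task scores; a placeholder dict stands in here.
        return {'MME': None}

def build_eval_list(config, eval_config, tokenizer):
    eval_list = []
    if config.eval.type == 'acc':
        eval_list.append(AccuracyEval(eval_config))
    elif config.eval.type == 'img_txt':   # the new VLM path
        eval_list.append(VLMEval(eval_config))
    else:
        eval_list.append(PerplexityEval(tokenizer, eval_config))
    return eval_list

# Note the asymmetry the diff preserves: the VLM evaluator receives the
# tokenizer at eval time, while the acc/ppl evaluators do not.
config = SimpleNamespace(eval=SimpleNamespace(type='img_txt', name='MME'))
for vlm_eval in build_eval_list(config, config.eval, tokenizer=None):
    results = vlm_eval.eval(model=None, tokenizer=None)
    print(f'{config.eval.name} results : {results}')
```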

llmc/compression/quantization/base_blockwise_quantization.py (25 additions, 10 deletions)
```diff
@@ -27,8 +27,8 @@
 
 
 class BaseBlockwiseQuantization(BlockwiseOpt):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
         self.set_quant_config()
 
     def w_qdq(self, module, wquantizer):
@@ -63,6 +63,9 @@ def a_qdq(self, act, module, aquantizer, input_index=0):
         else:
             return aquantizer.fake_quant_act_dynamic(act)
 
+    def logit(self, x):
+        return torch.log(x / (1 - x))
+
     def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
         params_dict = {}
         if mode == 'fake_quant':
@@ -268,6 +271,9 @@ def set_quant_config(self):
         self.intermediate_size = self.model.model_config.intermediate_size
         self.fp32_had = special_config.get('fp32_had', False)
 
+        self.quant_objects = self.quant_config.get('quant_objects', ['language'])
+        logger.info(f'self.quant_objects : {self.quant_objects}')
+
     def replace_rotate_linears(self, block):
         for n, m in block.named_modules():
             if isinstance(m, nn.Linear) and ('down_proj' in n
@@ -433,8 +439,7 @@ def run(self, block, input_feat, handles):
 
     def block_transform(self, block, input_feat, block_kwargs):
         logger.info(f'Start transform the {self.block_idx}-th block')
-        subsets = self.model.get_subsets_in_block(block) \
-            if self.modality == 'language' else self.model.get_encoder_subsets_in_block(block)
+        subsets = self.model.get_subsets_in_block(block)
 
         if self.act_static:
             self.register_non_linear_qparams(block, input_feat)
@@ -804,12 +809,22 @@ def deploy(self, quant_format, keep_device=False):
         )
 
         module = module_mapping[quant_format]
-        self.model.replace_module_all(
-            module,
-            self.get_replacement_params(mode=quant_format, w_only=self.w_only),
-            keep_device=keep_device
-        )
-        self.set_non_linear_mode(quant_format, self.model.model, False)
+        if 'vision' in self.quant_objects:
+            self.model.replace_vision_module_all(
+                module,
+                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
+                keep_device=keep_device
+            )
+        if 'language' in self.quant_objects:
+            self.model.replace_language_module_all(
+                module,
+                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
+                keep_device=keep_device
+            )
+        self.set_non_linear_mode(quant_format, self.model.model, False)
+
+        if hasattr(self.model, 'vlm_model'):
+            logger.info(f'Now, the vlm_model is: {self.model.vlm_model}')
 
         logger.info(f'-- deploy_{quant_format}_model done --')
 
```
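Two of the additions above deserve a note. The new `logit` helper is the inverse of the sigmoid; PyTorch already ships an equivalent, which makes the identity easy to check:

```python
import torch

x = torch.tensor([0.1, 0.5, 0.9])
manual = torch.log(x / (1 - x))                  # the helper added above
assert torch.allclose(manual, torch.special.logit(x))  # built-in equivalent
assert torch.allclose(torch.sigmoid(manual), x)        # logit inverts sigmoid
```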
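The larger change is that the `modality` argument plumbed through the constructor is gone; in its place, `set_quant_config` reads a `quant_objects` list from the quant config, and `deploy` fans out to vision and/or language module replacement. A hedged sketch of that dispatch, with toy stand-ins for the model-side hooks (only the names `replace_vision_module_all` / `replace_language_module_all` and the `quant_objects` gating come from the diff):

```python
# Toy stand-ins for the model wrappers; illustrative only.
class ToyVLM:
    def replace_vision_module_all(self, module_cls, params, keep_device=False):
        print(f'vision tower -> {module_cls.__name__}')

    def replace_language_module_all(self, module_cls, params, keep_device=False):
        print(f'language tower -> {module_cls.__name__}')

class FakeQuantLinear:  # placeholder for the real wrapper module class
    pass

def deploy(model, module_cls, params, quant_objects=('language',)):
    # Mirrors quant_config.get('quant_objects', ['language']): language-only
    # by default, with 'vision' as the opt-in for naive VLM quantization.
    if 'vision' in quant_objects:
        model.replace_vision_module_all(module_cls, params)
    if 'language' in quant_objects:
        model.replace_language_module_all(module_cls, params)

deploy(ToyVLM(), FakeQuantLinear, params={}, quant_objects=('vision', 'language'))
```

A quant config would presumably opt in with something like `quant_objects: ['vision', 'language']`; the key name is from the diff, the surrounding config layout is an assumption.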

llmc/compression/quantization/llmint8.py (1 addition, 1 deletion)
```diff
@@ -66,7 +66,7 @@ def deploy(self, quant_format):
         logger.info(f'-- deploy_{quant_format}_model start --')
         logger.info(f'quant_config : {self.quant_config}')
 
-        self.model.replace_module_all(
+        self.model.replace_language_module_all(
             FakeQuantLinear,
             self.get_replacement_params(
                 mode='fake_quant', w_only=self.w_only, name=None
```
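The rename matters for VLMs: the LLM.int8()-style path should fake-quantize only the language tower and leave the vision encoder in full precision. The real `replace_language_module_all` lives in llmc's model wrappers; the recursive walk below is only a plausible shape under that assumption, not the actual implementation:

```python
import torch.nn as nn

class FakeQuantLinearStub(nn.Module):
    """Stand-in for llmc's FakeQuantLinear; wraps an existing nn.Linear."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.inner = linear

    def forward(self, x):
        # A real version would fake-quantize weights/activations here.
        return self.inner(x)

def replace_language_module_all(root: nn.Module, wrapper_cls) -> None:
    # Swap every nn.Linear under `root`; for a VLM, `root` would be the
    # language model only, so the vision encoder is never visited.
    for name, child in root.named_children():
        if isinstance(child, nn.Linear):
            setattr(root, name, wrapper_cls(child))
        else:
            replace_language_module_all(child, wrapper_cls)

lm = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
replace_language_module_all(lm, FakeQuantLinearStub)
assert not any(isinstance(m, nn.Linear) for m in lm.children())
```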

llmc/eval/__init__.py (1 addition)
```diff
@@ -1,3 +1,4 @@
 from .eval_acc import AccuracyEval
 from .eval_ppl import PerplexityEval
 from .eval_token_consist import TokenConsistencyEval
+from .eval_vlm import VLMEval
```
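With this re-export in place, callers such as `llmc/__main__.py` above can pull the evaluator from the package namespace; the commented usage reflects only the signature visible in this commit:

```python
from llmc.eval import VLMEval  # now exported at package level

# Usage as in llmc/__main__.py (eval_config, model, tokenizer defined there):
# vlm_eval = VLMEval(eval_config)
# results = vlm_eval.eval(model, tokenizer)
```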
