1515from llmc.compression.quantization import *
1616from llmc.compression.sparsification import *
1717from llmc.data import BaseDataset, BaseTokenizer
18- from llmc.eval import AccuracyEval, PerplexityEval, TokenConsistencyEval
18+ from llmc.eval import (AccuracyEval, PerplexityEval, TokenConsistencyEval,
19+ VLMEval)
1920from llmc.models import *
2021from llmc.utils import (check_config, mkdirs, print_important_package_version,
2122 seed_all, update_autoawq_quant_config,
@@ -48,6 +49,9 @@ def main(config):
4849 if config.eval.type == 'acc':
4950 acc_eval = AccuracyEval(eval_config)
5051 eval_list.append(acc_eval)
52+ elif config.eval.type == 'img_txt':
53+ acc_eval = VLMEval(eval_config)
54+ eval_list.append(acc_eval)
5155 else:
5256 ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), eval_config)
5357 eval_list.append(ppl_eval)
@@ -57,6 +61,10 @@ def main(config):
5761 for acc_eval in eval_list:
5862 acc = acc_eval.eval(model)
5963 logger.info(f'{config.eval.name} acc : {acc}')
64+ elif config.eval.type == 'img_txt':
65+ for vlm_eval in eval_list:
66+ results = vlm_eval.eval(model, tokenizer)
67+ logger.info(f'{config.eval.name} results : {results}')
6068 else:
6169 for ppl_eval in eval_list:
6270 ppl = ppl_eval.eval(model)
@@ -76,18 +84,6 @@ def main(config):
7684 dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
7785 calib_data, padding_mask = dataset.get_calib_dataset()
7886 padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
79- if config.calib.type == 'img_txt':
80- model.collect_first_encoder_block_input(calib_data, padding_mask,
81- padding_side, config.calib.type)
82- blockwise_opt = ALGO_REGISTRY[config.quant.method](
83- model,
84- config.quant,
85- model.get_first_block_input(),
86- model.get_padding_mask(),
87- config,
88- 'vision'
89- )
90- blockwise_opt.run_block_loop()
9187 model.collect_first_block_input(calib_data, padding_mask, padding_side, config.calib.type)
9288 del calib_data
9389 gc.collect()
@@ -118,6 +114,10 @@ def main(config):
118114 for acc_eval in eval_list:
119115 acc = acc_eval.eval(model)
120116 logger.info(f'{config.eval.name} acc : {acc}')
117+ elif config.eval.type == 'img_txt':
118+ for vlm_eval in eval_list:
119+ results = vlm_eval.eval(model, tokenizer)
120+ logger.info(f'{config.eval.name} results : {results}')
121121 else:
122122 for ppl_eval in eval_list:
123123 ppl = ppl_eval.eval(model)
@@ -142,6 +142,10 @@ def main(config):
142142 for acc_eval in eval_list:
143143 acc = acc_eval.eval(model)
144144 logger.info(f'{config.eval.name} acc : {acc}')
145+ elif config.eval.type == 'img_txt':
146+ for vlm_eval in eval_list:
147+ results = vlm_eval.eval(model, tokenizer)
148+ logger.info(f'{config.eval.name} results : {results}')
145149 else:
146150 for ppl_eval in eval_list:
147151 ppl = ppl_eval.eval(model)
0 commit comments