diff --git a/llmc/__main__.py b/llmc/__main__.py
index f00fdd8d6..b0aabed0b 100644
--- a/llmc/__main__.py
+++ b/llmc/__main__.py
@@ -15,7 +15,8 @@
 from llmc.compression.quantization import *
 from llmc.compression.sparsification import *
 from llmc.data import BaseDataset, BaseTokenizer
-from llmc.eval import AccuracyEval, PerplexityEval, TokenConsistencyEval
+from llmc.eval import (AccuracyEval, PerplexityEval, TokenConsistencyEval,
+                       VLMEval)
 from llmc.models import *
 from llmc.utils import (check_config, mkdirs, print_important_package_version,
                         seed_all, update_autoawq_quant_config,
@@ -48,6 +49,9 @@ def main(config):
         if config.eval.type == 'acc':
             acc_eval = AccuracyEval(eval_config)
             eval_list.append(acc_eval)
+        elif config.eval.type == 'img_txt':
+            acc_eval = VLMEval(eval_config)
+            eval_list.append(acc_eval)
         else:
             ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), eval_config)
             eval_list.append(ppl_eval)
@@ -57,6 +61,10 @@ def main(config):
             for acc_eval in eval_list:
                 acc = acc_eval.eval(model)
                 logger.info(f'{config.eval.name} acc : {acc}')
+        elif config.eval.type == 'img_txt':
+            for vlm_eval in eval_list:
+                results = vlm_eval.eval(model, tokenizer)
+                logger.info(f'{config.eval.name} results : {results}')
         else:
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
@@ -76,18 +84,6 @@ def main(config):
         dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
         calib_data, padding_mask = dataset.get_calib_dataset()
         padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
-        if config.calib.type == 'img_txt':
-            model.collect_first_encoder_block_input(calib_data, padding_mask,
-                                                    padding_side, config.calib.type)
-            blockwise_opt = ALGO_REGISTRY[config.quant.method](
-                model,
-                config.quant,
-                model.get_first_block_input(),
-                model.get_padding_mask(),
-                config,
-                'vision'
-            )
-            blockwise_opt.run_block_loop()
         model.collect_first_block_input(calib_data, padding_mask, padding_side, config.calib.type)
         del calib_data
         gc.collect()
@@ -118,6 +114,10 @@ def main(config):
             for acc_eval in eval_list:
                 acc = acc_eval.eval(model)
                 logger.info(f'{config.eval.name} acc : {acc}')
+        elif config.eval.type == 'img_txt':
+            for vlm_eval in eval_list:
+                results = vlm_eval.eval(model, tokenizer)
+                logger.info(f'{config.eval.name} results : {results}')
         else:
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
@@ -142,6 +142,10 @@ def main(config):
             for acc_eval in eval_list:
                 acc = acc_eval.eval(model)
                 logger.info(f'{config.eval.name} acc : {acc}')
+        elif config.eval.type == 'img_txt':
+            for vlm_eval in eval_list:
+                results = vlm_eval.eval(model, tokenizer)
+                logger.info(f'{config.eval.name} results : {results}')
         else:
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index 6bb0bf273..1daef540e 100644
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -27,8 +27,8 @@


 class BaseBlockwiseQuantization(BlockwiseOpt):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
         self.set_quant_config()

     def w_qdq(self, module, wquantizer):
@@ -63,6 +63,9 @@ def a_qdq(self, act, module, aquantizer, input_index=0):
         else:
             return aquantizer.fake_quant_act_dynamic(act)

+    def logit(self, x):
+        return torch.log(x / (1 - x))
+
     def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
         params_dict = {}
         if mode == 'fake_quant':
@@ -268,6 +271,9 @@ def set_quant_config(self):
             self.intermediate_size = self.model.model_config.intermediate_size
         self.fp32_had = special_config.get('fp32_had', False)

+        self.quant_objects = self.quant_config.get('quant_objects', ['language'])
+        logger.info(f'self.quant_objects : {self.quant_objects}')
+
     def replace_rotate_linears(self, block):
         for n, m in block.named_modules():
             if isinstance(m, nn.Linear) and ('down_proj' in n
@@ -433,8 +439,7 @@ def run(self, block, input_feat, handles):

     def block_transform(self, block, input_feat, block_kwargs):
         logger.info(f'Start transform the {self.block_idx}-th block')
-        subsets = self.model.get_subsets_in_block(block) \
-            if self.modality == 'language' else self.model.get_encoder_subsets_in_block(block)
+        subsets = self.model.get_subsets_in_block(block)

         if self.act_static:
             self.register_non_linear_qparams(block, input_feat)
@@ -804,12 +809,22 @@ def deploy(self, quant_format, keep_device=False):
         )
         module = module_mapping[quant_format]

-        self.model.replace_module_all(
-            module,
-            self.get_replacement_params(mode=quant_format, w_only=self.w_only),
-            keep_device=keep_device
-        )
-        self.set_non_linear_mode(quant_format, self.model.model, False)
+        if 'vision' in self.quant_objects:
+            self.model.replace_vision_module_all(
+                module,
+                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
+                keep_device=keep_device
+            )
+        if 'language' in self.quant_objects:
+            self.model.replace_language_module_all(
+                module,
+                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
+                keep_device=keep_device
+            )
+            self.set_non_linear_mode(quant_format, self.model.model, False)
+
+        if hasattr(self.model, 'vlm_model'):
+            logger.info(f'Now, the vlm_model is: {self.model.vlm_model}')

         logger.info(f'-- deploy_{quant_format}_model done --')
diff --git a/llmc/compression/quantization/llmint8.py b/llmc/compression/quantization/llmint8.py
index 29209f63a..116b8843e 100644
--- a/llmc/compression/quantization/llmint8.py
+++ b/llmc/compression/quantization/llmint8.py
@@ -66,7 +66,7 @@ def deploy(self, quant_format):
         logger.info(f'-- deploy_{quant_format}_model start --')
         logger.info(f'quant_config : {self.quant_config}')

-        self.model.replace_module_all(
+        self.model.replace_language_module_all(
             FakeQuantLinear,
             self.get_replacement_params(
                 mode='fake_quant', w_only=self.w_only, name=None
diff --git a/llmc/eval/__init__.py b/llmc/eval/__init__.py
index 7fd4c3b60..435148c80 100644
--- a/llmc/eval/__init__.py
+++ b/llmc/eval/__init__.py
@@ -1,3 +1,4 @@
 from .eval_acc import AccuracyEval
 from .eval_ppl import PerplexityEval
 from .eval_token_consist import TokenConsistencyEval
+from .eval_vlm import VLMEval
diff --git a/llmc/eval/eval_vlm.py b/llmc/eval/eval_vlm.py
new file mode 100644
index 000000000..a87a80e3d
--- /dev/null
+++ b/llmc/eval/eval_vlm.py
@@ -0,0 +1,254 @@
+import gc
+import json
+import os
+from collections import defaultdict
+
+import torch
+from loguru import logger
+from sklearn.metrics import (accuracy_score, confusion_matrix, precision_score,
+                             recall_score)
+
+
+class VLMEval:
+    def __init__(self, eval_config):
+        self.eval_config = eval_config
+        self.dataset = eval_config['name']
+        assert self.dataset in [
+            'MME',
+        ], 'VLM eval only support MME dataset now.'
+        self.eval_dataset_path = eval_config['path']
+        self.eval_bs = eval_config['bs']
+        if self.dataset == 'MME':
+            self.img_qas = self.load_mme()
+        logger.info('VLMEval load dataset done.')
+
+    def load_mme(self):
+        img_qa_json = os.path.join(self.eval_dataset_path, 'img_qa.json')
+        fp = open(img_qa_json)
+        img_qas = json.load(fp)
+        for idx in range(len(img_qas)):
+            img_qas[idx]['img'] = os.path.join(
+                self.eval_dataset_path, img_qas[idx]['img']
+            )
+        return img_qas
+
+    def eval(self, model, tokenizer):
+        vlm_model = model.vlm_model
+        vlm_tokenizer = tokenizer.get_tokenizer()
+        vlm_model.cuda()
+        results = []
+        logger.info(f'len(self.img_qas): {len(self.img_qas)}')
+        logger.info(f'eval_bs: {self.eval_bs}')
+        for idx in range(0, len(self.img_qas), self.eval_bs):
+            logger.info(
+                f'index : {(idx + 1) // self.eval_bs}/{len(self.img_qas) // self.eval_bs}'
+            )
+            start = idx
+            end = min(idx + self.eval_bs, len(self.img_qas))
+            batch_samples = self.img_qas[start:end]
+            inputs = model.batch_process(batch_samples)
+            inputs = {
+                k: (
+                    v.to(next(vlm_model.parameters()).device)
+                    if torch.is_tensor(v)
+                    else v
+                )
+                for k, v in inputs.items()
+            }
+            outputs = vlm_model.generate(**inputs, max_new_tokens=32, do_sample=False)
+            gen_txts = vlm_tokenizer.batch_decode(
+                outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True
+            )
+            for n in range(len(batch_samples)):
+                result = batch_samples[n].copy()
+                result.update({'gen_txt': gen_txts[n]})
+                results.append(result)
+        if self.dataset == 'MME':
+            eval_class = MME()
+            vlm_score = eval_class(results)
+
+        vlm_model.cpu()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        return vlm_score
+
+
+class MME:
+    def __init__(self):
+        self.eval_type_dict = {
+            'Perception': [
+                'existence',
+                'count',
+                'position',
+                'color',
+                'posters',
+                'celebrity',
+                'scene',
+                'landmark',
+                'artwork',
+                'OCR',
+            ],
+            'Cognition': [
+                'commonsense_reasoning',
+                'numerical_calculation',
+                'text_translation',
+                'code_reasoning',
+            ],
+        }
+
+    def divide_chunks(self, lines, n=2):
+        # looping till length lines
+        for i in range(0, len(lines), n):
+            yield lines[i: i + n]
+
+        return
+
+    def parse_pred_ans(self, pred_ans):
+        pred_label = None
+        if pred_ans in ['yes', 'no']:
+            pred_label = pred_ans
+        else:
+            prefix_pred_ans = pred_ans[:4]
+
+            if 'yes' in prefix_pred_ans:
+                pred_label = 'yes'
+            elif 'no' in prefix_pred_ans:
+                pred_label = 'no'
+            else:
+                pred_label = 'other'
+
+        return pred_label
+
+    def compute_metric(self, gts, preds):
+        assert len(gts) == len(preds)
+
+        label_map = {
+            'yes': 1,
+            'no': 0,
+            'other': -1,
+        }
+
+        gts = [label_map[x] for x in gts]
+        preds = [label_map[x] for x in preds]
+
+        acc = accuracy_score(gts, preds)
+
+        clean_gts = []
+        clean_preds = []
+        other_num = 0
+        for gt, pred in zip(gts, preds):
+            if pred == -1:
+                other_num += 1
+                continue
+            clean_gts.append(gt)
+            clean_preds.append(pred)
+
+        conf_mat = confusion_matrix(clean_gts, clean_preds, labels=[1, 0])
+        precision = precision_score(clean_gts, clean_preds, average='binary')
+        recall = recall_score(clean_gts, clean_preds, average='binary')
+        tp, fn = conf_mat[0]
+        fp, tn = conf_mat[1]
+
+        metric_dict = dict()
+        metric_dict = {
+            'TP': tp,
+            'FN': fn,
+            'TN': tn,
+            'FP': fp,
+            'precision': precision,
+            'recall': recall,
+            'other_num': other_num,
+            'acc': acc,
+        }
+
+        return metric_dict
+
+    def get_lines(self, results):
+        lines_dict = defaultdict(list)
+        for res in results:
+            task_name = res['img'].split('/')[-2]
+            assert (
+                task_name in self.eval_type_dict['Perception']
+                or task_name in self.eval_type_dict['Cognition']
+            )
+            txt = (
+                res['img'].split('/')[-1]
+                + '\t'
+                + res['question']
+                + '\t'
+                + res['answer']
+                + '\t'
+                + res['gen_txt']
+                + '\n'
+            )
+            lines_dict[task_name].append(txt)
+        return lines_dict
+
+    def __call__(self, results):
+        lines_dict = self.get_lines(results)
+        mme_scores = {}
+        for eval_type, task_name_list in self.eval_type_dict.items():
+            mme_scores[eval_type] = {}
+
+            scores = 0
+            task_score_dict = dict()
+
+            for task_name in task_name_list:
+                lines = lines_dict[task_name]
+                chunk_lines = list(
+                    self.divide_chunks(lines)
+                )  # one image corresponds to two questions
+
+                img_num = len(chunk_lines)
+                task_other_ans_num = 0
+                task_score = 0
+                acc_plus_correct_num = 0
+                gts = []
+                preds = []
+
+                for img_items in chunk_lines:
+                    assert len(img_items) == 2
+                    img_correct_num = 0
+
+                    for img_item in img_items:
+                        img_name, question, gt_ans, pred_ans = img_item.split('\t')
+
+                        gt_ans = gt_ans.lower()
+                        pred_ans = pred_ans.lower()
+
+                        assert gt_ans in ['yes', 'no']  # gt can only be yes or no.
+
+                        pred_ans = self.parse_pred_ans(pred_ans)
+                        assert pred_ans in ['yes', 'no', 'other']
+
+                        gts.append(gt_ans)
+                        preds.append(pred_ans)
+
+                        if gt_ans == pred_ans:
+                            img_correct_num += 1
+
+                        if pred_ans not in ['yes', 'no']:
+                            task_other_ans_num += 1
+
+                    if img_correct_num == 2:
+                        acc_plus_correct_num += 1
+
+                # cal TP precision acc, etc.
+                metric_dict = self.compute_metric(gts, preds)
+                acc_plus = acc_plus_correct_num / img_num
+                metric_dict['acc_plus'] = acc_plus
+
+                for k, v in metric_dict.items():
+                    if k in ['acc', 'acc_plus']:
+                        task_score += v * 100
+
+                task_score_dict[task_name] = task_score
+
+                scores += task_score
+
+            mme_scores[eval_type]['total_score'] = scores
+            for task_name, score in task_score_dict.items():
+                mme_scores[eval_type][task_name] = score
+
+        return json.dumps(mme_scores, ensure_ascii=False, indent=4)
diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py
index 11993d750..97e5b0fad 100644
--- a/llmc/models/base_model.py
+++ b/llmc/models/base_model.py
@@ -28,7 +28,6 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         self.build_model()
         self.model.eval()
         self.find_blocks()
-        self.find_encoder_blocks()
         self.find_embed_layers()
         self.find_block_name()
         self.add_layernorms_class()
@@ -37,20 +36,14 @@
     def find_blocks(self):
         pass

-    def find_encoder_blocks(self):
-        pass
-
-    def get_encoder_catcher(self, first_block_input):
-        pass
-
     def find_block_name(self):
         pass

     def get_model(self):
         return self.model

-    def get_blocks(self, modality='language'):
-        return self.blocks if modality == 'language' else self.encoder_blocks
+    def get_blocks(self):
+        return self.blocks

     @abstractmethod
     def find_embed_layers(self):
@@ -193,43 +186,6 @@ def collect_first_block_input(self, calib_data, padding_mask=None, padding_side=
         self.blocks[0] = self.blocks[0].cpu()
         self.move_embed_to_device('cpu')

-    @torch.no_grad()
-    def collect_first_encoder_block_input(self, calib_data, padding_mask=None, padding_side=None, data_type='txt'):  # noqa
-        first_block_input = defaultdict(list)
-
-        Catcher = self.get_encoder_catcher(first_block_input)
-
-        self.move_embed_to_device('cuda')
-        if data_type == 'img_txt':
-            self.vision_model = self.vision_model.to('cuda')
-            self.projector = self.projector.to('cuda')
-        self.encoder_blocks[0] = self.encoder_blocks[0].cuda()
-        self.encoder_blocks[0] = Catcher(self.encoder_blocks[0])
-
-        for data in calib_data:
-            if isinstance(data, BatchFeature):
-                data = data.to(next(self.model.parameters()).device)
-            else:
-                data = {
-                    k: (v.to(next(self.model.parameters()).device) if torch.is_tensor(v) else v)
-                    for k, v in data.items()
-                }
-            try:
-                if data_type in ['txt', 'img']:
-                    self.model(**data)
-                elif data_type == 'img_txt':
-                    self.vlm_model.generate(**data, max_new_tokens=128, do_sample=False)
-            except ValueError:
-                pass
-        self.first_block_input = first_block_input
-        self.padding_mask = None
-        if data_type == 'img_txt':
-            self.vision_model = self.vision_model.cpu()
-            self.projector = self.projector.cpu()
-        self.encoder_blocks[0] = self.encoder_blocks[0].module
-        self.encoder_blocks[0] = self.encoder_blocks[0].cpu()
-        self.move_embed_to_device('cpu')
-
     def get_one_pad_setting(self, padding_side, length):
         if padding_side == 'left':
             return [0, length]
@@ -260,6 +216,13 @@ def get_block_linears(self, block):
             if isinstance(m, tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_))
         }

+    def get_all_linears(self, module):
+        return {
+            name: m
+            for name, m in module.named_modules()
+            if isinstance(m, tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_))
+        }
+
     def get_extra_modules(self, block):
         return {}

@@ -324,10 +287,30 @@ def set_mix_bits_params_dict(self, block_idx, name, params_dict):
                 params_mix_dict['a_qdq'] = None
         return params_mix_dict

-    def replace_modality_module_all(self, module, blocks, params_dict, keep_device=False):
-        for block_idx in range(len(blocks)):
-            logger.info(f'Replace block index: {block_idx}/{len(blocks)}')
-            block = blocks[block_idx]
+    def replace_vision_module_all(self, module, params_dict, keep_device=False):
+        vision_model_linears = self.get_block_linears(self.vision_model)
+        for name, m in vision_model_linears.items():
+            M = module.new(m, **params_dict)
+
+            name_tmp = name.rsplit('.', 1)
+            if len(name_tmp) == 2:
+                parent_name = name_tmp[0]
+                parent = self.vision_model.get_submodule(parent_name)
+                child_name = name_tmp[1]
+            elif len(name_tmp) == 1:
+                parent = self.vision_model
+                child_name = name_tmp[0]
+
+            setattr(parent, child_name, M)
+
+        gc.collect()
+        torch.cuda.empty_cache()
+        logger.info(f'The Replaced vision_model: {self.vision_model}')
+
+    def replace_language_module_all(self, module, params_dict, keep_device=False):
+        for block_idx in range(len(self.blocks)):
+            logger.info(f'Replace block index: {block_idx}/{len(self.blocks)}')
+            block = self.blocks[block_idx]
             if keep_device:
                 self.replace_module_block(module, block, block_idx, params_dict)
             else:
@@ -335,14 +318,6 @@ def replace_modality_module_all(self, module, blocks, params_dict, keep_device=F
                 self.replace_module_block(module, block, block_idx, params_dict)
                 block = block.cpu()

-    def replace_module_all(self, module, params_dict, keep_device=False):
-        if hasattr(self, 'encoder_blocks'):
-            logger.info('start replace vision blocks')
-            self.replace_modality_module_all(module, self.encoder_blocks, params_dict, keep_device)
-
-        logger.info('start replace language blocks')
-        self.replace_modality_module_all(module, self.blocks, params_dict, keep_device)
-
         gc.collect()
         torch.cuda.empty_cache()
         logger.info(f'The Replaced model: {self.model}')
diff --git a/tools/convert_mme.py b/tools/convert_mme.py
new file mode 100644
index 000000000..5f8009860
--- /dev/null
+++ b/tools/convert_mme.py
@@ -0,0 +1,37 @@
+import argparse
+import json
+import os
+
+from loguru import logger
+
+
+def convert_mme(mme_path):
+    img_qa_list = []
+    for root, dirs, files in os.walk(mme_path, topdown=False):
+        for name in files:
+            if name.endswith('.jpg') or name.endswith('.png'):
+                img_path = os.path.join(root, name)
+                img_path_tmp = img_path.split('/')
+                img = os.path.join(img_path_tmp[-3], img_path_tmp[-2], name)
+                txt_path = img_path[:-3] + 'txt'
+                fp = open(txt_path, 'r')
+                lines = fp.readlines()
+                for line in lines:
+                    question, answer = line.split('\t')
+                    img_qa = {
+                        'img': img.strip(),
+                        'question': question.strip(),
+                        'answer': answer.strip()
+                    }
+                    img_qa_list.append(img_qa)
+    fp = open('img_qa.json', 'w')
+    json.dump(img_qa_list, fp, indent=4)
+    logger.info('img_qa.json is done. You need to move it to MME file folder.')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--mme_path', type=str, required=True)
+    args = parser.parse_args()
+    logger.info(f'args : {args}')
+    convert_mme(args.mme_path)
diff --git a/tools/quant_analysis.py b/tools/quant_analysis.py
index b5daae629..1accb3594 100644
--- a/tools/quant_analysis.py
+++ b/tools/quant_analysis.py
@@ -441,7 +441,7 @@ def a_qdq(act, module=None):
     params_dict = {}
     params_dict['w_qdq'] = wquanter.fake_quant_weight_dynamic
     params_dict['a_qdq'] = None if args.w_only else a_qdq
-    t_model.replace_module_all(FakeQuantLinear, params_dict)
+    t_model.replace_language_module_all(FakeQuantLinear, params_dict)

     with torch.no_grad():
         for i in tqdm(range(len(model.blocks))):