diff --git a/llmc/__main__.py b/llmc/__main__.py
index b0aabed0b..a2d0557a8 100644
--- a/llmc/__main__.py
+++ b/llmc/__main__.py
@@ -42,18 +42,18 @@ def main(config):
             else [config.eval.name]
         )
         for name in name_list:
-            eval_config = copy.deepcopy(config.eval)
-            eval_config.name = name
+            config_for_eval = copy.deepcopy(config)
+            config_for_eval.eval.name = name
             if len(name_list) != 1:
                 # eval multi datasets
-                eval_config.path = os.path.join(config.eval.path, name)
+                config_for_eval.eval.path = os.path.join(config.eval.path, name)
             if config.eval.type == 'acc':
-                acc_eval = AccuracyEval(eval_config)
+                acc_eval = AccuracyEval(config_for_eval)
                 eval_list.append(acc_eval)
             elif config.eval.type == 'img_txt':
-                acc_eval = VLMEval(eval_config)
+                acc_eval = VLMEval(config_for_eval)
                 eval_list.append(acc_eval)
             else:
-                ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), eval_config)
+                ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), config_for_eval)
                 eval_list.append(ppl_eval)
     if 'eval' in config and 'pretrain' in config.eval.eval_pos:
@@ -156,7 +156,7 @@ def main(config):
                 config.model.path, config.model.torch_dtype
             )
             token_consist_eval = TokenConsistencyEval(tokenizer.get_tokenizer(),
-                                                      eval_config)
+                                                      config_for_eval)
             consistency_ratio = token_consist_eval.eval(model, org_model)
             logger.info(f'Token consistency ratio: {consistency_ratio}')
             del org_model
diff --git a/llmc/eval/eval_acc.py b/llmc/eval/eval_acc.py
index 91f4aed59..26a855180 100644
--- a/llmc/eval/eval_acc.py
+++ b/llmc/eval/eval_acc.py
@@ -8,9 +8,9 @@
 
 
 class AccuracyEval:
-    def __init__(self, eval_config, batch_size=256, num_workers=8):
-        self.eval_config = eval_config
-        self.imagenet_root = eval_config['path']
+    def __init__(self, config, batch_size=256, num_workers=8):
+        self.eval_config = config.eval
+        self.imagenet_root = self.eval_config['path']
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
diff --git a/llmc/eval/eval_base.py b/llmc/eval/eval_base.py
index b25447cea..f34e76fe3 100644
--- a/llmc/eval/eval_base.py
+++ b/llmc/eval/eval_base.py
@@ -8,9 +8,10 @@
 
 
 class BaseEval:
-    def __init__(self, tokenizer, eval_cfg):
+    def __init__(self, tokenizer, config):
         self.tokenizer = tokenizer
         # eval_cfg
+        eval_cfg = config.eval
         logger.info(f'eval_cfg : {eval_cfg}')
         self.dataset = eval_cfg['name']
         assert self.dataset in [
diff --git a/llmc/eval/eval_vlm.py b/llmc/eval/eval_vlm.py
index a87a80e3d..af519c61c 100644
--- a/llmc/eval/eval_vlm.py
+++ b/llmc/eval/eval_vlm.py
@@ -10,16 +10,17 @@
 
 
 class VLMEval:
-    def __init__(self, eval_config):
-        self.eval_config = eval_config
-        self.dataset = eval_config['name']
+    def __init__(self, config):
+        self.eval_config = config.eval
+        self.dataset = self.eval_config['name']
         assert self.dataset in [
             'MME',
         ], 'VLM eval only support MME dataset now.'
-        self.eval_dataset_path = eval_config['path']
-        self.eval_bs = eval_config['bs']
+        self.eval_dataset_path = self.eval_config['path']
+        self.eval_bs = self.eval_config['bs']
         if self.dataset == 'MME':
             self.img_qas = self.load_mme()
+            self.patch_datasets(config.model.type)
         logger.info('VLMEval load dataset done.')
 
     def load_mme(self):
@@ -32,6 +33,13 @@ def load_mme(self):
             )
         return img_qas
 
+    def patch_datasets(self, model_type):
+        if self.dataset == 'MME':
+            if model_type == 'InternVL2':
+                for idx in range(len(self.img_qas)):
+                    if '\n' not in self.img_qas[idx]['question']:
+                        self.img_qas[idx]['question'] = '\n' + self.img_qas[idx]['question']
+
     def eval(self, model, tokenizer):
         vlm_model = model.vlm_model
         vlm_tokenizer = tokenizer.get_tokenizer()