Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions llmc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ def main(config):
else [config.eval.name]
)
for name in name_list:
eval_config = copy.deepcopy(config.eval)
eval_config.name = name
config_for_eval = copy.deepcopy(config)
config_for_eval.eval.name = name
if len(name_list) != 1: # eval multi datasets
eval_config.path = os.path.join(config.eval.path, name)
config_for_eval.eval.path = os.path.join(config.eval.path, name)
if config.eval.type == 'acc':
acc_eval = AccuracyEval(eval_config)
acc_eval = AccuracyEval(config_for_eval)
eval_list.append(acc_eval)
elif config.eval.type == 'img_txt':
acc_eval = VLMEval(eval_config)
acc_eval = VLMEval(config_for_eval)
eval_list.append(acc_eval)
else:
ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), eval_config)
ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), config_for_eval)
eval_list.append(ppl_eval)

if 'eval' in config and 'pretrain' in config.eval.eval_pos:
Expand Down Expand Up @@ -156,7 +156,7 @@ def main(config):
config.model.path, config.model.torch_dtype
)
token_consist_eval = TokenConsistencyEval(tokenizer.get_tokenizer(),
eval_config)
config_for_eval)
consistency_ratio = token_consist_eval.eval(model, org_model)
logger.info(f'Token consistency ratio: {consistency_ratio}')
del org_model
Expand Down
6 changes: 3 additions & 3 deletions llmc/eval/eval_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@


class AccuracyEval:
def __init__(self, eval_config, batch_size=256, num_workers=8):
self.eval_config = eval_config
self.imagenet_root = eval_config['path']
def __init__(self, config, batch_size=256, num_workers=8):
self.eval_config = config.eval
self.imagenet_root = self.eval_config['path']
self.batch_size = batch_size
self.num_workers = num_workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand Down
3 changes: 2 additions & 1 deletion llmc/eval/eval_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@


class BaseEval:
def __init__(self, tokenizer, eval_cfg):
def __init__(self, tokenizer, config):
self.tokenizer = tokenizer
# eval_cfg
eval_cfg = config.eval
logger.info(f'eval_cfg : {eval_cfg}')
self.dataset = eval_cfg['name']
assert self.dataset in [
Expand Down
18 changes: 13 additions & 5 deletions llmc/eval/eval_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@


class VLMEval:
def __init__(self, eval_config):
self.eval_config = eval_config
self.dataset = eval_config['name']
def __init__(self, config):
self.eval_config = config.eval
self.dataset = self.eval_config['name']
assert self.dataset in [
'MME',
], 'VLM eval only support MME dataset now.'
self.eval_dataset_path = eval_config['path']
self.eval_bs = eval_config['bs']
self.eval_dataset_path = self.eval_config['path']
self.eval_bs = self.eval_config['bs']
if self.dataset == 'MME':
self.img_qas = self.load_mme()
self.patch_datasets(config.model.type)
logger.info('VLMEval load dataset done.')

def load_mme(self):
Expand All @@ -32,6 +33,13 @@ def load_mme(self):
)
return img_qas

def patch_datasets(self, model_type):
    """Adapt the loaded question/answer records to the given model type.

    Only the combination of the 'MME' dataset and the 'InternVL2' model
    type is patched: every question that does not already contain the
    image placeholder line gets that placeholder prepended. Any other
    dataset or model type leaves the records untouched.

    Args:
        model_type: Model type string (the constructor passes
            ``config.model.type``).

    Returns:
        None. The entries of ``self.img_qas`` are modified in place.
    """
    # Guard clause: nothing to patch outside the MME + InternVL2 pairing.
    if self.dataset != 'MME' or model_type != 'InternVL2':
        return

    image_tag = '<image>\n'
    for qa in self.img_qas:
        # Questions that already carry the tag are skipped, so calling
        # this method more than once does not stack placeholders.
        if image_tag not in qa['question']:
            qa['question'] = image_tag + qa['question']

def eval(self, model, tokenizer):
vlm_model = model.vlm_model
vlm_tokenizer = tokenizer.get_tokenizer()
Expand Down
Loading