diff --git a/llmc/eval/eval_vlm.py b/llmc/eval/eval_vlm.py
index af519c61c..60e82aa5f 100644
--- a/llmc/eval/eval_vlm.py
+++ b/llmc/eval/eval_vlm.py
@@ -18,6 +18,7 @@ def __init__(self, config):
         ], 'VLM eval only support MME dataset now.'
         self.eval_dataset_path = self.eval_config['path']
         self.eval_bs = self.eval_config['bs']
+        self.output_include_input = True
         if self.dataset == 'MME':
             self.img_qas = self.load_mme()
             self.patch_datasets(config.model.type)
@@ -39,6 +40,10 @@ def patch_datasets(self, model_type):
         for idx in range(len(self.img_qas)):
             if '\n' not in self.img_qas[idx]['question']:
                 self.img_qas[idx]['question'] = '\n' + self.img_qas[idx]['question']
+        if model_type == 'InternVL2':
+            self.output_include_input = False
+        elif model_type == 'Llava':
+            self.output_include_input = True
 
     def eval(self, model, tokenizer):
         vlm_model = model.vlm_model
@@ -64,9 +69,14 @@ def eval(self, model, tokenizer):
                 for k, v in inputs.items()
             }
             outputs = vlm_model.generate(**inputs, max_new_tokens=32, do_sample=False)
-            gen_txts = vlm_tokenizer.batch_decode(
-                outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True
-            )
+            if self.output_include_input:
+                gen_txts = vlm_tokenizer.batch_decode(
+                    outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True
+                )
+            else:
+                gen_txts = vlm_tokenizer.batch_decode(
+                    outputs, skip_special_tokens=True
+                )
             for n in range(len(batch_samples)):
                 result = batch_samples[n].copy()
                 result.update({'gen_txt': gen_txts[n]})
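
Context for the `output_include_input` flag: LLaVA-style models echo the prompt token IDs in front of the newly generated IDs, so the prompt must be sliced off before decoding, while InternVL2's `generate` (presumably because it decodes from `inputs_embeds`) appears to return only the new tokens, and slicing by `input_ids.shape[1]` would discard real output. Below is a minimal, standalone sketch of the same branch; the GPT-2 tokenizer and token tensors are placeholders for illustration only and are not part of llmc.

```python
import torch
from transformers import AutoTokenizer

# Standalone illustration of the `output_include_input` branch in the diff.
# Tokenizer and token IDs are made up for the example, not llmc's objects.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

prompt_ids = tokenizer('Describe the image.', return_tensors='pt')['input_ids']
new_ids = torch.tensor([[1212, 318, 257, 3797, 13]])  # pretend these were generated

# LLaVA-style output: the prompt is echoed back, followed by new tokens.
outputs_with_prompt = torch.cat([prompt_ids, new_ids], dim=1)
# InternVL2-style output: only the newly generated tokens.
outputs_new_only = new_ids

for output_include_input, outputs in [(True, outputs_with_prompt),
                                      (False, outputs_new_only)]:
    if output_include_input:
        # Strip the echoed prompt before decoding.
        gen_txts = tokenizer.batch_decode(
            outputs[:, prompt_ids.shape[1]:], skip_special_tokens=True
        )
    else:
        # Nothing to strip; decode the output as-is.
        gen_txts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    print(output_include_input, gen_txts)  # both branches yield the same text
```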