Commit 82a9936

Enable text-only evals for VLM models (#2999)
1 parent 9d29ef0 commit 82a9936

File tree

4 files changed: +33 −11 lines


lm_eval/evaluator.py

Lines changed: 0 additions & 4 deletions
@@ -494,10 +494,6 @@ def evaluate(
             raise ValueError(
                 f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
             )
-        else:
-            raise ValueError(
-                f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
-            )
     # end validation check

     # Cache the limit arg.
lm_eval/models/hf_vlms.py

Lines changed: 14 additions & 3 deletions
@@ -399,6 +399,9 @@ def _batch_images(self, image_encs):
         return batched_imgs

     def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+        if requests and len(requests[0].args) < 3:
+            # Fall back to non-multimodal generation.
+            return super().loglikelihood_rolling(requests=requests)
         raise NotImplementedError(
             "model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks ",
             "this is because we do not support measuring the loglikelihood a model assigns to an image.",
@@ -407,6 +410,9 @@ def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
     def loglikelihood(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[Tuple[float, bool]]:
+        if requests and len(requests[0].args) < 3:
+            # Fall back to non-multimodal generation.
+            return super().loglikelihood(requests=requests, disable_tqdm=disable_tqdm)
         raise NotImplementedError(
             "'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
         )
@@ -433,9 +439,11 @@ def loglikelihood(
                 )
             )

-        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
+        return self._multimodal_loglikelihood_tokens(
+            new_reqs, disable_tqdm=disable_tqdm
+        )

-    def _loglikelihood_tokens(
+    def _multimodal_loglikelihood_tokens(
         self,
         requests: List[
             Tuple[Tuple[None, str, str], List[int], List[int], List[int]]
@@ -624,7 +632,10 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
     def generate_until(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[str]:
-        # TODO: back out to HFLM.generate_until() for all requests without aux_arguments (text-only reqs)
+        if requests and len(requests[0].args) < 3:
+            # Fall back to non-multimodal generation.
+            return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
+
         res = []

         def _collate(x):
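
Note: each overridden method in this file now opens with the same guard: if the first request in the batch carries fewer than three positional args, the batch is treated as text-only and handed back to the parent HFLM implementation. A small sketch of that rule, assuming a text-only request packs (context, gen_kwargs) while a multimodal request adds a third auxiliary-arguments entry (the exact contents of that entry are an assumption here):

# Sketch of the text-only fallback guard used throughout this commit.
from typing import Any, Tuple

def is_text_only_batch(first_request_args: Tuple[Any, ...]) -> bool:
    # Only the first request of the batch is inspected; the batch is
    # assumed to be homogeneous (all text-only or all multimodal).
    return len(first_request_args) < 3

print(is_text_only_batch(("2 + 2 =", {"until": ["\n"]})))                      # True  -> parent HFLM path
print(is_text_only_batch(("Describe:", {"until": ["\n"]}, {"visual": None})))  # False -> multimodal path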

lm_eval/models/huggingface.py

Lines changed: 4 additions & 1 deletion
@@ -890,7 +890,10 @@ def _model_call(self, inps, attn_mask=None, labels=None):
                 input_ids=inps, attention_mask=attn_mask, labels=labels
             ).logits
         else:
-            assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+            assert self.AUTO_MODEL_CLASS in (
+                transformers.AutoModelForCausalLM,
+                transformers.AutoModelForVision2Seq,
+            )
             return self.model(inps).logits

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
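
Note: only the type assertion is widened here, so that a Vision2Seq-backed wrapper can reuse the text-only forward pass. A simplified sketch of the dispatch, with the encoder-decoder branch assumed from surrounding code rather than shown in this diff:

# Simplified sketch of _model_call after the change (the Seq2Seq condition is
# assumed from surrounding code; only the assert is touched by this commit).
import transformers

def model_call_sketch(lm, inps, attn_mask=None, labels=None):
    if lm.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
        # Encoder-decoder models need explicit labels to produce logits here.
        return lm.model(
            input_ids=inps, attention_mask=attn_mask, labels=labels
        ).logits
    # Decoder-only and Vision2Seq-backed models score text with a bare forward.
    assert lm.AUTO_MODEL_CLASS in (
        transformers.AutoModelForCausalLM,
        transformers.AutoModelForVision2Seq,
    )
    return lm.model(inps).logits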

lm_eval/models/vllm_vlms.py

Lines changed: 15 additions & 3 deletions
@@ -106,7 +106,7 @@ def tok_batch_multimodal_encode(
             outputs.append(inputs)
         return outputs

-    def _model_generate(
+    def _multimodal_model_generate(
         self,
         requests: List[List[dict]] = None,
         generate: bool = False,
@@ -218,7 +218,10 @@ def apply_chat_template(
     def generate_until(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[str]:
-        # TODO: support text-only reqs
+        if requests and len(requests[0].args) < 3:
+            # Fall back to non-multimodal generation.
+            return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
+
         res = []

         def _collate(x):
@@ -293,7 +296,7 @@ def _collate(x):
             left_truncate_len=max_ctx_len,
         )

-        cont = self._model_generate(
+        cont = self._multimodal_model_generate(
             inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs
         )

@@ -309,3 +312,12 @@ def _collate(x):

         pbar.close()
         return res
+
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+        if requests and len(requests[0].args) < 3:
+            # Fall back to non-multimodal generation.
+            return super().loglikelihood_rolling(requests=requests)
+        raise NotImplementedError(
+            "model type `vllm-vlm` does not support loglikelihood_rolling. Use 'vlm' model type for text-only loglikelihood_rolling tasks ",
+            "this is because we do not support measuring the loglikelihood a model assigns to an image.",
+        )
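
Note: with these fallbacks in place, a text-only benchmark can be pointed at a VLM model type directly. A hedged usage sketch through the library's Python entry point (the checkpoint name is a placeholder for whatever vision-language model you actually run; limit is set only to keep the smoke test small):

# Hypothetical usage sketch: run a text-only task against a multimodal model type.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf-multimodal",
    model_args="pretrained=llava-hf/llava-1.5-7b-hf",  # placeholder checkpoint
    tasks=["hellaswag"],  # text-only task; previously rejected for VLM model types
    limit=8,              # small subset for a quick smoke test
)
print(results["results"])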
