diff --git a/lmms_eval/models/simple/qwen2_vl.py b/lmms_eval/models/simple/qwen2_vl.py index bc5ae3a06..863a9f2e0 100755 --- a/lmms_eval/models/simple/qwen2_vl.py +++ b/lmms_eval/models/simple/qwen2_vl.py @@ -188,6 +188,16 @@ def _collate(x): # Import utils here if flatten is moved import lmms_eval.utils as utils + def _ensure_list(v): # normalize a doc_to_visual result into a list + if v is None: + return [] + if isinstance(v, list): + # unwrap one level when the value is double-wrapped, e.g. [[img]] + if len(v) == 1 and isinstance(v[0], list): + return v[0] + return v + return [v] # any other single value -> one-element list + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") # we group requests by their generation_kwargs, # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling @@ -202,13 +212,10 @@ def _collate(x): # TODO: Clarify the behavior of doc_to_visual for documents without visual info. # The current logic might incorrectly discard all visuals if one doc lacks them. # Ensure flatten is appropriate here based on doc_to_visual's return type. - visual_list = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] - if None in visual_list: # This check might need refinement - # If a mix of visual/non-visual is possible, this needs careful handling - # Currently sets all visuals to empty if any doc returns None - visual_list = [] - else: - visual_list = self.flatten(visual_list) # Assumes doc_to_visual returns list of lists + visuals_per_doc = [] # one normalized list of visuals per (doc_to_visual, doc_id) pair + for fn, ids in zip(doc_to_visual, doc_id): + v = fn(self.task_dict[task][split][ids]) + visuals_per_doc.append(_ensure_list(v)) gen_kwargs = all_gen_kwargs[0] if all_gen_kwargs else {} @@ -249,7 +256,7 @@ def _collate(x): # Needs careful review based on doc_to_visual output structure # For simplicity, assuming visual_list contains all visuals for the batch for now # A more robust approach might map visuals back to their original context index. 
- relevant_visuals = visual_list # Placeholder: needs logic to get visuals for context 'i' + relevant_visuals = visuals_per_doc[i] # visuals for context 'i' only (assumes 'i' follows the doc_id order used to build visuals_per_doc; confirm) for visual in relevant_visuals: if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file