Commit c79490b
fix: filter multimodal content from log samples while preserving metadata (#962)
* fix: improve spatialviz utils quality
  - Fix FileExistsError -> FileNotFoundError (correct exception type)
  - Replace print() with eval_logger for consistent logging
  - Add type hints to all functions
  - Fix missing comma bug in final_answer_patterns list
  - Remove redundant image_path = image_path assignment
  - Initialize op variable to prevent potential UnboundLocalError
  - Break long prompt string for readability (88 char line limit)

* style: apply black formatting

* fix: filter multimodal content from log samples while preserving metadata

  When using --log_samples, the previous implementation either saved all fields (causing serialization issues with images/audio) or filtered based on key names (missing useful metadata like image_id, image_path). This fix introduces is_multimodal_content() that detects actual multimodal data types (PIL.Image, numpy arrays, torch tensors, HuggingFace audio/image dicts) while preserving all scalar metadata fields for dataset traceability.

  Github-Issue: #943
1 parent 76d573d commit c79490b

3 files changed: +90 −31 lines changed

lmms_eval/evaluator.py

Lines changed: 3 additions & 1 deletion
@@ -41,6 +41,7 @@
     get_git_commit_hash,
     handle_non_serializable,
     hash_string,
+    is_multimodal_content,
     make_table,
     positional_deprecated,
     run_task_tests,
@@ -562,7 +563,8 @@ def evaluate(
             target = task.doc_to_target(doc)
             saved_doc = {}
             for key, value in doc.items():
-                saved_doc[key] = value
+                if not is_multimodal_content(value):
+                    saved_doc[key] = value
             filtered_arguments = []
             for req in requests:
                 # check if req.args is a list of tuples, and each item in the list is a serializable object
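To illustrate what the new filtering loop does, here is a minimal sketch; the doc fields below are hypothetical examples, and is_multimodal_content is the helper added to lmms_eval/utils.py in this commit:

from PIL import Image

from lmms_eval.utils import is_multimodal_content

# Hypothetical doc mixing a multimodal payload with scalar metadata
doc = {
    "image": Image.new("RGB", (64, 64)),  # PIL image -> dropped from the log
    "image_id": "CP_0001",                # scalar metadata -> kept
    "Question": "Which option shows the folded cube?",
    "Answer": "B",
}

saved_doc = {}
for key, value in doc.items():
    if not is_multimodal_content(value):
        saved_doc[key] = value

# saved_doc now holds only JSON-serializable fields:
# {"image_id": "CP_0001", "Question": "...", "Answer": "B"}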

lmms_eval/tasks/spatialviz/utils.py

Lines changed: 59 additions & 30 deletions
@@ -2,9 +2,11 @@
 import re
 from collections import defaultdict
 from pathlib import Path
+from typing import Any, Dict, List
 
 import yaml
 from huggingface_hub import snapshot_download
+from loguru import logger as eval_logger
 from PIL import Image
 
 with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
@@ -23,7 +25,7 @@
 )
 
 
-def spatialviz_doc_to_visual(doc):
+def spatialviz_doc_to_visual(doc: Dict[str, Any]) -> List[Image.Image]:
     visual = []
 
     category = doc["Category"]
@@ -33,16 +35,21 @@ def spatialviz_doc_to_visual(doc):
     image_path = f"{cache_dir}/{category}/{task}/{level}/{image_id}.png"
 
     if os.path.exists(image_path):
-        image_path = image_path
         visual.append(Image.open(image_path).convert("RGB"))
     else:
-        raise FileExistsError(f"video path:{image_path} does not exist.")
+        raise FileNotFoundError(f"image path: {image_path} does not exist.")
     return visual
 
 
-def spatialviz_doc_to_text(doc):
+def spatialviz_doc_to_text(doc: Dict[str, Any]) -> str:
     ops = ["A", "B", "C", "D"]
-    prompt = "You should first provide a reasoning process, then provide a single option(A, B, C or D) as the final answer. The reasoning process and the answer are enclosed within <think></think> and <answer></answer> tags, respectively, i.e., <think>reasoning process</think>, <answer>answer</answer>.\n"
+    prompt = (
+        "You should first provide a reasoning process, then provide a single "
+        "option(A, B, C or D) as the final answer. The reasoning process and "
+        "the answer are enclosed within <think></think> and <answer></answer> "
+        "tags, respectively, i.e., <think>reasoning process</think>, "
+        "<answer>answer</answer>.\n"
+    )
     question = doc["Question"]
     choices = doc["Choices"]
     choice_text = ""
@@ -53,7 +60,7 @@ def spatialviz_doc_to_text(doc):
     return text
 
 
-def spatialviz_process_results(doc, results):
+def spatialviz_process_results(doc: Dict[str, Any], results: List[str]) -> Dict[str, Dict[str, Any]]:
     key_name = "spatialviz_score"
     grounded_output = doc["Answer"]
     response = results[0]
@@ -63,14 +70,28 @@ def spatialviz_process_results(doc, results):
 
     think_match = re.search(think_pattern, response, re.DOTALL)
     answer_match = re.search(answer_pattern, response, re.DOTALL)
+
+    op: List[str] = []
     if think_match and answer_match:
         final_answer = answer_match.group(1).strip()
         pred_answer = final_answer.split(".")[0]
         op = re.findall(r"[A-D]", pred_answer)
-
     else:
-        print("No match for think/answer \n")
-        final_answer_patterns = ["<answer>", "Answer:", "Final answer", "final answer", "Final Answer", "the answer is", "The answer is", "correct answer", "Correct answer", "Correct Answer", "答案" "correct path"]
+        eval_logger.debug("No match for think/answer tags in response")
+        final_answer_patterns = [
+            "<answer>",
+            "Answer:",
+            "Final answer",
+            "final answer",
+            "Final Answer",
+            "the answer is",
+            "The answer is",
+            "correct answer",
+            "Correct answer",
+            "Correct Answer",
+            "答案",
+            "correct path",
+        ]
         if len(response) == 1:
             op = re.findall(r"[A-D]", response)
         else:
@@ -88,14 +109,23 @@ def spatialviz_process_results(doc, results):
         is_correct = False
 
     query = spatialviz_doc_to_text(doc)
-    spatialviz_submission = {"id": doc["Image_id"], "query": query, "gt_content": grounded_output, "pred": response, "category": doc["Category"], "task": doc["Task"], "level": doc["Level"], "is_correct": is_correct}
+    spatialviz_submission = {
+        "id": doc["Image_id"],
+        "query": query,
+        "gt_content": grounded_output,
+        "pred": response,
+        "category": doc["Category"],
+        "task": doc["Task"],
+        "level": doc["Level"],
+        "is_correct": is_correct,
+    }
     return {key_name: spatialviz_submission}
 
 
-def spatialviz_aggregate_results(results):
-    task_to_eval_samples = defaultdict(list)
-    category_to_eval_samples = defaultdict(list)
-    key_to_eval_samples = defaultdict(list)
+def spatialviz_aggregate_results(results: List[Dict[str, Any]]) -> float:
+    task_to_eval_samples: Dict[str, List[int]] = defaultdict(list)
+    category_to_eval_samples: Dict[str, List[int]] = defaultdict(list)
+    key_to_eval_samples: Dict[str, List[int]] = defaultdict(list)
     total_samples = len(results)
     total_correct = 0
 
@@ -120,26 +150,25 @@ def spatialviz_aggregate_results(results):
     task_accuracies = {task: sum(scores) / len(scores) for task, scores in task_to_eval_samples.items()}
     category_accuracies = {category: sum(scores) / len(scores) for category, scores in category_to_eval_samples.items()}
     key_accuracies = {key: sum(scores) / len(scores) for key, scores in key_to_eval_samples.items()}
-    print(f"{'Total Samples':<20}: {total_samples}")
-    print(f"{'Total Correct':<20}: {total_correct}")
-    print(f"{'Overall Accuracy':<20}: {accuracy:.4f}")
-    print()
 
-    print(f"{'Per-Task Accuracy':<40}")
-    print("-" * 40)
+    eval_logger.info(f"{'Total Samples':<20}: {total_samples}")
+    eval_logger.info(f"{'Total Correct':<20}: {total_correct}")
+    eval_logger.info(f"{'Overall Accuracy':<20}: {accuracy:.4f}")
+
+    eval_logger.info(f"{'Per-Task Accuracy':<40}")
+    eval_logger.info("-" * 40)
     for task, acc in task_accuracies.items():
-        print(f"{task:<20}: {acc:.4f}")
-    print()
+        eval_logger.info(f"{task:<20}: {acc:.4f}")
 
-    print(f"{'Per-Category Accuracy':<40}")
-    print("-" * 40)
+    eval_logger.info(f"{'Per-Category Accuracy':<40}")
+    eval_logger.info("-" * 40)
     for category, acc in category_accuracies.items():
-        print(f"{category:<20}: {acc:.4f}")
-    print("=" * 40)
+        eval_logger.info(f"{category:<20}: {acc:.4f}")
+    eval_logger.info("=" * 40)
 
-    print(f"{'Per-Key Accuracy':<40}")
-    print("-" * 40)
+    eval_logger.info(f"{'Per-Key Accuracy':<40}")
+    eval_logger.info("-" * 40)
     for key, acc in key_accuracies.items():
-        print(f"{key:<20}: {acc:.4f}")
-    print()
+        eval_logger.info(f"{key:<20}: {acc:.4f}")
+
     return accuracy
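For readers unfamiliar with the tag-based parsing in spatialviz_process_results, the standalone sketch below walks through the happy path. The exact think_pattern and answer_pattern regexes live in unchanged code outside this diff, so the patterns used here are assumptions chosen to match the <think>/<answer> convention described in the prompt:

import re

# Assumed patterns; the real ones are defined earlier in utils.py (not shown in this diff)
think_pattern = r"<think>(.*?)</think>"
answer_pattern = r"<answer>(.*?)</answer>"

response = "<think>The net folds so that B faces up.</think><answer>B.</answer>"

think_match = re.search(think_pattern, response, re.DOTALL)
answer_match = re.search(answer_pattern, response, re.DOTALL)

op = []
if think_match and answer_match:
    final_answer = answer_match.group(1).strip()  # "B."
    pred_answer = final_answer.split(".")[0]      # "B"
    op = re.findall(r"[A-D]", pred_answer)        # ["B"]

print(op)  # ["B"], later compared against doc["Answer"]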

lmms_eval/utils.py

Lines changed: 28 additions & 0 deletions
@@ -102,6 +102,34 @@ def handle_non_serializable(o):
     return str(o)
 
 
+def is_multimodal_content(value: Any) -> bool:
+    """
+    Check if a value is multimodal content (image, audio, video) that should
+    not be serialized to log files.
+
+    Returns True for:
+    - PIL.Image objects
+    - numpy arrays (typically image/audio data)
+    - bytes (binary data)
+    - torch tensors
+    - dicts with 'array' key (HuggingFace audio format)
+    - dicts with 'bytes' key (HuggingFace image format)
+    """
+    if isinstance(value, (bytes, bytearray, np.ndarray, torch.Tensor)):
+        return True
+    if isinstance(value, dict):
+        if "array" in value or "bytes" in value:
+            return True
+    try:
+        from PIL import Image
+
+        if isinstance(value, Image.Image):
+            return True
+    except ImportError:
+        pass
+    return False
+
+
 def sanitize_list(sub):
     """
     Takes possible nested list and recursively converts all inner component to strings
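As a quick sanity check of the detection logic added above, here is a small usage sketch; the values are illustrative only, and the helper is imported from lmms_eval/utils.py as introduced in this commit:

import numpy as np
from PIL import Image

from lmms_eval.utils import is_multimodal_content

# Multimodal payloads: filtered out of --log_samples output
assert is_multimodal_content(Image.new("RGB", (8, 8)))                       # PIL image
assert is_multimodal_content(np.zeros(16000, dtype=np.float32))              # raw audio samples
assert is_multimodal_content(b"\x89PNG\r\n")                                 # binary image bytes
assert is_multimodal_content({"array": [0.0, 0.1], "sampling_rate": 16000})  # HF audio dict
assert is_multimodal_content({"bytes": b"...", "path": "img.png"})           # HF image dict

# Scalar metadata: preserved for dataset traceability
assert not is_multimodal_content("CP_0001")            # e.g. image_id
assert not is_multimodal_content("data/img/0001.png")  # e.g. image_path
assert not is_multimodal_content(42)
assert not is_multimodal_content(["A", "B", "C", "D"])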
