Commit 0196082

fix: improve spatialviz utils quality
- Fix FileExistsError -> FileNotFoundError (correct exception type)
- Replace print() with eval_logger for consistent logging
- Add type hints to all functions
- Fix missing comma bug in final_answer_patterns list (see the sketch below)
- Remove redundant image_path = image_path assignment
- Initialize op variable to prevent potential UnboundLocalError (see the sketch after the diff)
- Break long prompt string for readability (88 char line limit)
1 parent 76d573d commit 0196082
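
For context on the missing-comma item: Python concatenates adjacent string literals at compile time, so the old list entry "答案" "correct path" was silently merged into the single pattern "答案correct path", and neither intended pattern could match on its own. A minimal standalone sketch of that behaviour (not code from the repository):

# Adjacent string literals are implicitly concatenated, so a missing comma
# silently merges two list entries into one.
buggy = ["答案" "correct path"]    # one element: "答案correct path"
fixed = ["答案", "correct path"]   # two separate patterns, as intended

assert buggy == ["答案correct path"]
assert fixed == ["答案", "correct path"]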


lmms_eval/tasks/spatialviz/utils.py

Lines changed: 72 additions & 34 deletions
@@ -2,9 +2,11 @@
 import re
 from collections import defaultdict
 from pathlib import Path
+from typing import Any, Dict, List
 
 import yaml
 from huggingface_hub import snapshot_download
+from loguru import logger as eval_logger
 from PIL import Image
 
 with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
@@ -23,7 +25,7 @@
 )
 
 
-def spatialviz_doc_to_visual(doc):
+def spatialviz_doc_to_visual(doc: Dict[str, Any]) -> List[Image.Image]:
     visual = []
 
     category = doc["Category"]
@@ -33,16 +35,21 @@ def spatialviz_doc_to_visual(doc):
     image_path = f"{cache_dir}/{category}/{task}/{level}/{image_id}.png"
 
     if os.path.exists(image_path):
-        image_path = image_path
         visual.append(Image.open(image_path).convert("RGB"))
     else:
-        raise FileExistsError(f"video path:{image_path} does not exist.")
+        raise FileNotFoundError(f"image path: {image_path} does not exist.")
     return visual
 
 
-def spatialviz_doc_to_text(doc):
+def spatialviz_doc_to_text(doc: Dict[str, Any]) -> str:
     ops = ["A", "B", "C", "D"]
-    prompt = "You should first provide a reasoning process, then provide a single option(A, B, C or D) as the final answer. The reasoning process and the answer are enclosed within <think></think> and <answer></answer> tags, respectively, i.e., <think>reasoning process</think>, <answer>answer</answer>.\n"
+    prompt = (
+        "You should first provide a reasoning process, then provide a single "
+        "option(A, B, C or D) as the final answer. The reasoning process and "
+        "the answer are enclosed within <think></think> and <answer></answer> "
+        "tags, respectively, i.e., <think>reasoning process</think>, "
+        "<answer>answer</answer>.\n"
+    )
     question = doc["Question"]
     choices = doc["Choices"]
     choice_text = ""
@@ -53,7 +60,9 @@ def spatialviz_doc_to_text(doc):
     return text
 
 
-def spatialviz_process_results(doc, results):
+def spatialviz_process_results(
+    doc: Dict[str, Any], results: List[str]
+) -> Dict[str, Dict[str, Any]]:
     key_name = "spatialviz_score"
     grounded_output = doc["Answer"]
     response = results[0]
@@ -63,14 +72,28 @@ def spatialviz_process_results(doc, results):
 
     think_match = re.search(think_pattern, response, re.DOTALL)
     answer_match = re.search(answer_pattern, response, re.DOTALL)
+
+    op: List[str] = []
     if think_match and answer_match:
         final_answer = answer_match.group(1).strip()
         pred_answer = final_answer.split(".")[0]
         op = re.findall(r"[A-D]", pred_answer)
-
     else:
-        print("No match for think/answer \n")
-        final_answer_patterns = ["<answer>", "Answer:", "Final answer", "final answer", "Final Answer", "the answer is", "The answer is", "correct answer", "Correct answer", "Correct Answer", "答案" "correct path"]
+        eval_logger.debug("No match for think/answer tags in response")
+        final_answer_patterns = [
+            "<answer>",
+            "Answer:",
+            "Final answer",
+            "final answer",
+            "Final Answer",
+            "the answer is",
+            "The answer is",
+            "correct answer",
+            "Correct answer",
+            "Correct Answer",
+            "答案",
+            "correct path",
+        ]
         if len(response) == 1:
             op = re.findall(r"[A-D]", response)
         else:
@@ -88,14 +111,23 @@ def spatialviz_process_results(doc, results):
         is_correct = False
 
     query = spatialviz_doc_to_text(doc)
-    spatialviz_submission = {"id": doc["Image_id"], "query": query, "gt_content": grounded_output, "pred": response, "category": doc["Category"], "task": doc["Task"], "level": doc["Level"], "is_correct": is_correct}
+    spatialviz_submission = {
+        "id": doc["Image_id"],
+        "query": query,
+        "gt_content": grounded_output,
+        "pred": response,
+        "category": doc["Category"],
+        "task": doc["Task"],
+        "level": doc["Level"],
+        "is_correct": is_correct,
+    }
     return {key_name: spatialviz_submission}
 
 
-def spatialviz_aggregate_results(results):
-    task_to_eval_samples = defaultdict(list)
-    category_to_eval_samples = defaultdict(list)
-    key_to_eval_samples = defaultdict(list)
+def spatialviz_aggregate_results(results: List[Dict[str, Any]]) -> float:
+    task_to_eval_samples: Dict[str, List[int]] = defaultdict(list)
+    category_to_eval_samples: Dict[str, List[int]] = defaultdict(list)
+    key_to_eval_samples: Dict[str, List[int]] = defaultdict(list)
     total_samples = len(results)
     total_correct = 0
 
@@ -117,29 +149,35 @@ def spatialviz_aggregate_results(results):
             key_to_eval_samples[key].append(0)
 
     accuracy = total_correct / total_samples if total_samples > 0 else 0
-    task_accuracies = {task: sum(scores) / len(scores) for task, scores in task_to_eval_samples.items()}
-    category_accuracies = {category: sum(scores) / len(scores) for category, scores in category_to_eval_samples.items()}
-    key_accuracies = {key: sum(scores) / len(scores) for key, scores in key_to_eval_samples.items()}
-    print(f"{'Total Samples':<20}: {total_samples}")
-    print(f"{'Total Correct':<20}: {total_correct}")
-    print(f"{'Overall Accuracy':<20}: {accuracy:.4f}")
-    print()
-
-    print(f"{'Per-Task Accuracy':<40}")
-    print("-" * 40)
+    task_accuracies = {
+        task: sum(scores) / len(scores) for task, scores in task_to_eval_samples.items()
+    }
+    category_accuracies = {
+        category: sum(scores) / len(scores)
+        for category, scores in category_to_eval_samples.items()
+    }
+    key_accuracies = {
+        key: sum(scores) / len(scores) for key, scores in key_to_eval_samples.items()
+    }
+
+    eval_logger.info(f"{'Total Samples':<20}: {total_samples}")
+    eval_logger.info(f"{'Total Correct':<20}: {total_correct}")
+    eval_logger.info(f"{'Overall Accuracy':<20}: {accuracy:.4f}")
+
+    eval_logger.info(f"{'Per-Task Accuracy':<40}")
+    eval_logger.info("-" * 40)
     for task, acc in task_accuracies.items():
-        print(f"{task:<20}: {acc:.4f}")
-        print()
+        eval_logger.info(f"{task:<20}: {acc:.4f}")
 
-    print(f"{'Per-Category Accuracy':<40}")
-    print("-" * 40)
+    eval_logger.info(f"{'Per-Category Accuracy':<40}")
+    eval_logger.info("-" * 40)
    for category, acc in category_accuracies.items():
-        print(f"{category:<20}: {acc:.4f}")
-        print("=" * 40)
+        eval_logger.info(f"{category:<20}: {acc:.4f}")
+        eval_logger.info("=" * 40)
 
-    print(f"{'Per-Key Accuracy':<40}")
-    print("-" * 40)
+    eval_logger.info(f"{'Per-Key Accuracy':<40}")
+    eval_logger.info("-" * 40)
     for key, acc in key_accuracies.items():
-        print(f"{key:<20}: {acc:.4f}")
-        print()
+        eval_logger.info(f"{key:<20}: {acc:.4f}")
+
     return accuracy