import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List

import yaml
from huggingface_hub import snapshot_download
from loguru import logger as eval_logger
from PIL import Image

with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
    # ... (collapsed in the source diff: the template yaml is parsed here and
    # `cache_dir` is set to the downloaded dataset snapshot path) ...
    ...

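# The four functions below follow the lmms-eval task-hook convention
# (doc_to_visual / doc_to_text / process_results / aggregate_results) and are
# referenced from the task yaml via !function.
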
def spatialviz_doc_to_visual(doc: Dict[str, Any]) -> List[Image.Image]:
    visual = []

    category = doc["Category"]
    task = doc["Task"]
    level = doc["Level"]
    image_id = doc["Image_id"]
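    # Images live in the downloaded snapshot, laid out by category/task/level.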
    image_path = f"{cache_dir}/{category}/{task}/{level}/{image_id}.png"

    if os.path.exists(image_path):
        visual.append(Image.open(image_path).convert("RGB"))
    else:
        raise FileNotFoundError(f"image path: {image_path} does not exist.")
    return visual


def spatialviz_doc_to_text(doc: Dict[str, Any]) -> str:
    ops = ["A", "B", "C", "D"]
    prompt = (
        "You should first provide a reasoning process, then provide a single "
        "option(A, B, C or D) as the final answer. The reasoning process and "
        "the answer are enclosed within <think></think> and <answer></answer> "
        "tags, respectively, i.e., <think>reasoning process</think>, "
        "<answer>answer</answer>.\n"
    )
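    # The grading logic in spatialviz_process_results expects exactly this
    # <think>/<answer> tag format.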
    question = doc["Question"]
    choices = doc["Choices"]
    choice_text = ""
    # (reconstructed from context; the loop body is collapsed in the source
    # diff) letter each choice A-D and append it; the exact prompt assembly
    # below is an assumption.
    for i, choice in enumerate(choices):
        choice_text += f"{ops[i]}. {choice}\n"
    text = prompt + question + "\n" + choice_text
    return text


def spatialviz_process_results(
    doc: Dict[str, Any], results: List[str]
) -> Dict[str, Dict[str, Any]]:
    key_name = "spatialviz_score"
    grounded_output = doc["Answer"]
    response = results[0]
    # (reconstructed from context; collapsed in the source diff) regexes that
    # capture the tagged reasoning and answer spans.
    think_pattern = r"<think>(.*?)</think>"
    answer_pattern = r"<answer>(.*?)</answer>"

    think_match = re.search(think_pattern, response, re.DOTALL)
    answer_match = re.search(answer_pattern, response, re.DOTALL)

    op: List[str] = []
    if think_match and answer_match:
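        # Well-formed response: take the <answer> span, drop anything after
        # the first ".", and pull out the option letter.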
        final_answer = answer_match.group(1).strip()
        pred_answer = final_answer.split(".")[0]
        op = re.findall(r"[A-D]", pred_answer)
    else:
        eval_logger.debug("No match for think/answer tags in response")
        final_answer_patterns = [
            "<answer>",
            "Answer:",
            "Final answer",
            "final answer",
            "Final Answer",
            "the answer is",
            "The answer is",
            "correct answer",
            "Correct answer",
            "Correct Answer",
            "答案",  # "answer" in Chinese
            "correct path",
        ]
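        # A single-character reply is treated as the bare option letter.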
        if len(response) == 1:
            op = re.findall(r"[A-D]", response)
        else:
            # (reconstructed from context; collapsed in the source diff) scan
            # for a cue phrase and read the first A-D letter after it; the
            # exact scan below is an assumption.
            for pattern in final_answer_patterns:
                if pattern in response:
                    op = re.findall(r"[A-D]", response.split(pattern)[-1])
                    break

    # (reconstructed from context) the extracted option is compared against
    # the ground-truth letter.
    if op and op[0] == grounded_output:
        is_correct = True
    else:
        is_correct = False

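    # Everything needed for error analysis is recorded per sample; the
    # aggregator below buckets on category/task/level and reads is_correct.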
    query = spatialviz_doc_to_text(doc)
    spatialviz_submission = {
        "id": doc["Image_id"],
        "query": query,
        "gt_content": grounded_output,
        "pred": response,
        "category": doc["Category"],
        "task": doc["Task"],
        "level": doc["Level"],
        "is_correct": is_correct,
    }
    return {key_name: spatialviz_submission}

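# A minimal sketch of how these hooks are typically wired into the task yaml
# (assumed; the actual yaml for this task is not shown in this diff):
#
#   doc_to_visual: !function utils.spatialviz_doc_to_visual
#   doc_to_text: !function utils.spatialviz_doc_to_text
#   process_results: !function utils.spatialviz_process_results
#   metric_list:
#     - metric: spatialviz_score
#       aggregation: !function utils.spatialviz_aggregate_results
#       higher_is_better: true
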
def spatialviz_aggregate_results(results: List[Dict[str, Any]]) -> float:
    task_to_eval_samples: Dict[str, List[int]] = defaultdict(list)
    category_to_eval_samples: Dict[str, List[int]] = defaultdict(list)
    key_to_eval_samples: Dict[str, List[int]] = defaultdict(list)
    total_samples = len(results)
    total_correct = 0

    # (reconstructed from context; collapsed in the source diff) tally each
    # sample's 0/1 score under its task, its category, and a combined key;
    # the key construction below (task + level) is an assumption.
    for result in results:
        task = result["task"]
        category = result["category"]
        key = f"{result['task']}_{result['level']}"
        if result["is_correct"]:
            total_correct += 1
            task_to_eval_samples[task].append(1)
            category_to_eval_samples[category].append(1)
            key_to_eval_samples[key].append(1)
        else:
            task_to_eval_samples[task].append(0)
            category_to_eval_samples[category].append(0)
            key_to_eval_samples[key].append(0)

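    # Per-bucket accuracy is the mean of its 0/1 scores; overall accuracy is
    # total_correct over all samples.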
    accuracy = total_correct / total_samples if total_samples > 0 else 0
    task_accuracies = {
        task: sum(scores) / len(scores) for task, scores in task_to_eval_samples.items()
    }
    category_accuracies = {
        category: sum(scores) / len(scores)
        for category, scores in category_to_eval_samples.items()
    }
    key_accuracies = {
        key: sum(scores) / len(scores) for key, scores in key_to_eval_samples.items()
    }
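
    # Emit a readable summary to the evaluation log.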
    eval_logger.info(f"{'Total Samples':<20} : {total_samples}")
    eval_logger.info(f"{'Total Correct':<20} : {total_correct}")
    eval_logger.info(f"{'Overall Accuracy':<20} : {accuracy:.4f}")

    eval_logger.info(f"{'Per-Task Accuracy':<40}")
    eval_logger.info("-" * 40)
    for task, acc in task_accuracies.items():
        eval_logger.info(f"{task:<20} : {acc:.4f}")

    eval_logger.info(f"{'Per-Category Accuracy':<40}")
    eval_logger.info("-" * 40)
    for category, acc in category_accuracies.items():
        eval_logger.info(f"{category:<20} : {acc:.4f}")
    eval_logger.info("=" * 40)

    eval_logger.info(f"{'Per-Key Accuracy':<40}")
    eval_logger.info("-" * 40)
    for key, acc in key_accuracies.items():
        eval_logger.info(f"{key:<20} : {acc:.4f}")

    return accuracy