Commit f12b79a

add mmsi-bench (#715)
1 parent ada79d6 commit f12b79a

2 files changed: +181 -0
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
dataset_path: RunsenXu/MMSI-Bench

task: "mmsi_bench"
dataset_kwargs:
  token: True
test_split: test
output_type: generate_until
doc_to_visual: !function utils.msr_doc_to_visual
doc_to_text: !function utils.msr_doc_to_text
doc_to_target: "answer"
process_results: !function utils.msr_process_results

lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer with the option's letter from the given choices directly. Enclose the option's letter within ``."

generation_kwargs:
  max_new_tokens: 2048
  temperature: 0
  do_sample: False

metric_list:
  - metric: Positional Relationship (Obj.-Obj.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: Positional Relationship (Cam.-Obj.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: Positional Relationship (Cam.-Cam.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: Positional Relationship (Obj.-Reg.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: Positional Relationship (Cam.-Reg.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: Positional Relationship (Reg.-Reg.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: Attribute (Meas.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: Attribute (Appr.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: Motion (Obj.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: Motion (Cam.)
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: MSR
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true
  - metric: average
    aggregation: !function utils.msr_aggregate_results
    higher_is_better: true

metadata:
  - version: 0.0
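Below is a minimal sketch (not part of this commit) of how the default pre_prompt and post_prompt above are applied to a question by utils.msr_doc_to_text; the sample doc is hypothetical, and the import path simply follows the utils.py file added in this commit.

from lmms_eval.tasks.mmsi_bench.utils import msr_doc_to_text

# Hypothetical document; "question" already contains the lettered choices.
doc = {"question": "Which chair is closer to the camera?\nA. The red chair\nB. The blue chair"}
default_kwargs = {
    "pre_prompt": "",
    "post_prompt": "\nAnswer with the option's letter from the given choices directly. Enclose the option's letter within ``.",
}
print(msr_doc_to_text(doc, default_kwargs))
# Which chair is closer to the camera?
# A. The red chair
# B. The blue chair
# Answer with the option's letter from the given choices directly. Enclose the option's letter within ``.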

lmms_eval/tasks/mmsi_bench/utils.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
import io
import logging
import re
from collections import defaultdict

import numpy as np
import pandas as pd
from PIL import Image

from lmms_eval.filters.extraction import ExtendedRegexFilter
from lmms_eval.filters.transformation import MapFilter

eval_logger = logging.getLogger("lmms-eval")


def msr_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Build the text prompt by wrapping the question with the configured pre/post prompts."""
    if lmms_eval_specific_kwargs is None:
        lmms_eval_specific_kwargs = {}
    question = doc["question"].strip()
    if "pre_prompt" in lmms_eval_specific_kwargs and lmms_eval_specific_kwargs["pre_prompt"] != "":
        question = f"{lmms_eval_specific_kwargs['pre_prompt']}{question}"
    if "post_prompt" in lmms_eval_specific_kwargs and lmms_eval_specific_kwargs["post_prompt"] != "":
        question = f"{question}{lmms_eval_specific_kwargs['post_prompt']}"
    return question


def msr_doc_to_visual(doc):
    """Decode the raw image bytes of a document into a list of RGB PIL images."""
    image_list = []
    for img_data in doc["images"]:
        image = Image.open(io.BytesIO(img_data))
        image = image.convert("RGB")
        image_list.append(image)
    return image_list


def extract_single_choice_with_word_boundary(pred, gt):
    """Extract a single choice letter (A-D) from the prediction and score it against the ground truth.

    Returns 1.0 for a match, 0.0 for a mismatch, and None when no choice letter can be found.
    """
    # Prefer a span enclosed in double backticks (the format requested by the post_prompt),
    # then single backticks, then curly braces.
    pattern_1 = r"``([^`]*)``"
    match = re.search(pattern_1, pred)
    if match:
        pred = match.group(1)

    pattern_2 = r"`([^`]*)`"
    match = re.search(pattern_2, pred)
    if match:
        pred = match.group(1)

    pattern_add = r"\{([^}]*)\}"
    match = re.search(pattern_add, pred)
    if match:
        pred = match.group(1)

    # Finally, look for a standalone A-D that is not followed by whitespace and another word
    # (e.g. "A" but not "A chair").
    pattern_3 = r"\b[A-D]\b(?!\s[a-zA-Z])"
    match = re.search(pattern_3, pred)
    if match:
        pred = match.group()
    else:
        return None

    answer = gt.lower().replace("\n", " ").strip()
    predict = pred.lower().replace("\n", " ").strip()
    try:
        if answer == predict[0]:
            return 1.0
        elif predict[0] == "(" and answer == predict[1]:
            return 1.0
        elif predict[0:7] == "option " and answer == predict[7]:
            return 1.0
        elif predict[0:14] == "the answer is " and answer == predict[14]:
            return 1.0
    except Exception:
        return 0.0
    return 0.0


def msr_process_results(doc, results):
    """
    Args:
        doc: an instance of the eval dataset
        results: [pred]
    Returns:
        a dictionary with key: metric name, value: metric value
    """
    pred = results[0]
    gt = doc["answer"]

    score = extract_single_choice_with_word_boundary(pred, gt)
    category = doc["question_type"]
    l2_category = doc["question_type"]
    if score is None:
        # No choice letter could be extracted: count the question as wrong and record a note.
        return {
            category: {"question_id": doc["id"], "l2_category": l2_category, "score": 0, "note": "cannot find answer"},
            "average": {"question_id": doc["id"], "l2_category": l2_category, "score": 0, "note": "cannot find answer"},
        }
    return {
        category: {"question_id": doc["id"], "l2_category": l2_category, "score": score},
        "average": {"question_id": doc["id"], "l2_category": l2_category, "score": score},
    }


def msr_aggregate_results(results):
    """
    Args:
        results: a list of values returned by process_results
    Returns:
        A score
    """
    l2_category_scores = defaultdict(list)
    for result in results:
        score = result["score"]
        l2_category = result["l2_category"]
        l2_category_scores[l2_category].append(score)

    l2_category_avg_score = {}
    for l2_category, scores in l2_category_scores.items():
        avg_score = sum(scores) / len(scores)
        l2_category_avg_score[l2_category] = avg_score
        eval_logger.info(f"{l2_category}: {avg_score:.2f}")

    all_scores = [score for scores in l2_category_scores.values() for score in scores]
    avg_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
    return avg_score
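The following sketch (also not part of the commit) shows which prediction formats extract_single_choice_with_word_boundary recovers a letter from; the strings are made-up model outputs.

from lmms_eval.tasks.mmsi_bench.utils import extract_single_choice_with_word_boundary

# Letter enclosed in double backticks, as the task's post_prompt requests.
print(extract_single_choice_with_word_boundary("``B``", "B"))               # 1.0
# Single backticks and curly braces are also handled.
print(extract_single_choice_with_word_boundary("The answer is {C}.", "C"))  # 1.0
# A wrong letter still parses but scores 0.
print(extract_single_choice_with_word_boundary("`A`", "B"))                 # 0.0
# No standalone A-D letter at all: None, which msr_process_results turns into 0 plus a note.
print(extract_single_choice_with_word_boundary("I am not sure.", "D"))      # None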

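And a last sketch (again not part of the commit) of how per-question results flow from msr_process_results into msr_aggregate_results; the documents and predictions are hypothetical.

from lmms_eval.tasks.mmsi_bench.utils import msr_aggregate_results, msr_process_results

docs = [
    {"id": 0, "question_type": "Motion (Cam.)", "answer": "A"},
    {"id": 1, "question_type": "Motion (Cam.)", "answer": "C"},
]
preds = ["``A``", "The answer is B"]

# Each call returns one entry for the question's own category and one for "average".
per_question = [msr_process_results(doc, [pred]) for doc, pred in zip(docs, preds)]
motion_cam = [r["Motion (Cam.)"] for r in per_question]
print(msr_aggregate_results(motion_cam))  # 0.5: one correct, one wrong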