Skip to content

Commit b7fbb1c

Browse files
committed
mmlu cot group agg result and mmlu flexible extract
1 parent 428feb1 commit b7fbb1c

File tree

4 files changed

+120
-3
lines changed

4 files changed

+120
-3
lines changed

lmms_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ task:
2626
- metric: acc
2727
weight_by_size: True
2828
aggregate_metric_list:
29-
- metric: acc
30-
weight_by_size: True
29+
- aggregation: mean
30+
metric: exact_match
31+
weight_by_size: true
32+
filter_list: flexible-extract
3133
metadata:
3234
version: 2

lmms_eval/tasks/mmlu_pro/_default_template_yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@ filter_list:
1515
regex_pattern: 'answer is \(?([ABCDEFGHIJ])\)?'
1616
# regex_pattern: r".*[aA]nswer:\s*([A-J])",
1717
- function: "take_first"
18+
- name: "strict-match"
19+
filter:
20+
- function: "regex"
21+
regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))"
22+
- function: "take_first"
23+
- name: "flexible-extract"
24+
filter:
25+
- function: !function utils.MultiChoiceRegexFilter
26+
group_select: -1
27+
ignore_case: true
28+
ignore_punctuation: true
29+
regex_pattern: "(\\([A-Z]\\))"
30+
- function: "take_first"
1831
generation_kwargs:
1932
until:
2033
- "</s>"

lmms_eval/tasks/mmlu_pro/_mmlu_pro.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ aggregate_metric_list:
1818
- aggregation: mean
1919
metric: exact_match
2020
weight_by_size: true
21-
filter_list: custom-extract
21+
filter_list: flexible-extract
2222
metadata:
2323
version: 1.0

lmms_eval/tasks/mmlu_pro/utils.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
import re
2+
import sys
3+
import unicodedata
14
from functools import partial
25

6+
from lmms_eval.filters.extraction import RegexFilter
7+
38
choices = [
49
"A",
510
"B",
@@ -58,3 +63,100 @@ def process_docs(dataset, subject):
5863
process_philosophy = partial(process_docs, subject="philosophy")
5964
process_physics = partial(process_docs, subject="physics")
6065
process_psychology = partial(process_docs, subject="psychology")
66+
67+
68+
class MultiChoiceRegexFilter(RegexFilter):
    """Regex filter that maps free-form responses onto multiple-choice labels.

    Every response is run through up to three extraction strategies, stopping
    at the first one that yields a non-empty result:

    1. the configured ``regex_pattern`` (used through ``self.regex``, which
       the ``RegexFilter`` base class is expected to provide);
    2. the literal text of each entry in ``doc["options"]``, searched for in
       the response after both have been normalised; a hit is converted to
       ``"(A)"``, ``"(B)"``, ... according to the option's position;
    3. a bare option letter that follows a colon (e.g. ``": B"``), which is
       converted to ``"(B)"``.

    If all three fail, ``self.fallback`` is emitted for that response.
    """

    # Translation table that deletes every Unicode punctuation character.
    # Building it scans the whole code-point range, so it is created lazily
    # and shared by all instances instead of being rebuilt on every apply().
    _punct_tbl = None

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        """Configure the filter.

        Args:
            regex_pattern: The basic regex pattern to use. If it fails to
                match, the customized match procedure is used:
                step 1 looks for the text of the document's options in the
                response; step 2 looks for a colon, optional whitespace and
                one of the option letters (the set of letters varies with
                the number of options).
            group_select: Selects the (group_select)th match from the
                ``findall`` result.
            fallback: Value emitted when no strategy produces a match.
            ignore_case: Lower-case both options and response during step 1.
            ignore_punctuation: Remove punctuation from both options and
                response during step 1.
            regexes_to_ignore: Optional iterable of regexes whose matches are
                removed from both options and response during step 1.
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    @classmethod
    def _punctuation_table(cls):
        """Return the shared ``str.translate`` table that strips punctuation."""
        if MultiChoiceRegexFilter._punct_tbl is None:
            MultiChoiceRegexFilter._punct_tbl = dict.fromkeys(
                i for i in range(sys.maxunicode + 1) if unicodedata.category(chr(i)).startswith("P")
            )
        return MultiChoiceRegexFilter._punct_tbl

    def _normalize(self, st):
        """Apply the configured ignore rules (regexes, case, punctuation) to ``st``."""
        if self.regexes_to_ignore is not None:
            for s in self.regexes_to_ignore:
                st = re.sub(s, "", st)

        if self.ignore_case:
            st = st.lower()

        if self.ignore_punctuation:
            # https://stackoverflow.com/a/266162
            st = st.translate(self._punctuation_table())
        return st

    def _find_match(self, regex, resp, convert_dict=None):
        """Return the selected match of ``regex`` in ``resp``, or ``""`` if none.

        When the pattern has several groups, the first non-empty group of the
        selected match is used. If ``convert_dict`` contains the stripped
        match, the mapped value is returned instead.
        """
        found = regex.findall(resp)
        if not found:
            return ""
        picked = found[self.group_select]
        if isinstance(picked, tuple):
            # A match whose groups are all empty counts as "no match" rather
            # than raising IndexError.
            picked = next((g for g in picked if g), "")
        picked = picked.strip()
        if picked and convert_dict and picked in convert_dict:
            picked = convert_dict[picked]
        return picked

    def apply(self, resps, docs):
        """Extract one answer per response.

        Args:
            resps: A list in which each element is the list of model
                responses for one input/target pair.
            docs: The documents aligned with ``resps``; each must expose its
                answer options under the ``"options"`` key.

        Returns:
            A list with the same nesting as ``resps`` in which every response
            is replaced by the extracted answer or by ``self.fallback``.
        """
        # Each set of responses belonging to the same input/target pair is
        # processed independently, and the nesting is preserved.
        filtered_resps = []

        for r, doc in zip(resps, docs):
            choice_patterns = []
            choice_to_alpha = {}
            next_alpha = "A"

            letters = []
            letter_to_target = {}

            for c in doc["options"]:
                m = self._normalize(c.strip())
                target = f"({next_alpha})"
                # An option that normalises to "" would add an empty
                # alternative to the regex, which matches everywhere and
                # hides every option listed after it, so it is left out.
                if m:
                    choice_patterns.append(re.escape(m))
                    choice_to_alpha[m] = target

                letters.append(re.escape(next_alpha))
                letter_to_target[next_alpha] = target

                next_alpha = chr(ord(next_alpha) + 1)

            choice_regex = re.compile("|".join(choice_patterns)) if choice_patterns else None
            letter_alternatives = "|".join(letters)
            letter_regex = re.compile(rf":[\s]*({letter_alternatives})")

            filtered = []
            for resp in r:
                match = self._find_match(self.regex, resp)
                if not match and choice_regex is not None:
                    match = self._find_match(choice_regex, self._normalize(resp), choice_to_alpha)
                if not match:
                    match = self._find_match(letter_regex, resp, letter_to_target)
                if not match:
                    match = self.fallback
                filtered.append(match)
            filtered_resps.append(filtered)

        return filtered_resps

0 commit comments

Comments
 (0)