Skip to content

Commit e12f4b7

Browse files
committed
Rebase and fix some issues
1 parent fa53d98 commit e12f4b7

File tree

3 files changed: +8 -5 lines changed

3 files changed: +8 -5 lines changed

tests/utils/eval_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def test_extract_answer(self):
37 37   self.assertEqual(
38 38   actual_output,
39 39   expected_output,
40    - "Failed on input: '{input_str}'\nExpected: '{expected_output}', Got: '{actual_output}'",
   40 + f"Failed on input: '{input_str}'\nExpected: '{expected_output}', Got: '{actual_output}'",
41 41   )
42 42
43 43   def test_verify_math_answer(self):

trinity/common/workflows/eval_workflow.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,10 @@ def run(self) -> List[Experience]:
71 71   responses: List[Experience] = self.model.chat(messages, **self.eval_gen_args)
72 72
73 73   for response in responses:
74    - accuracy, eval_details = verify_math_answer(
   74 + if response.response_text is None or self.task.truth is None:
   75 + continue
   76 +
   77 + accuracy, _ = verify_math_answer(
75 78   response_text=response.response_text, ground_truth=self.task.truth
76 79   )
77 80

trinity/utils/math_eval_utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
25 25   from word2number import w2n
26 26
27 27
28    - def verify_math_answer(response_text, ground_truth) -> Tuple[float, Dict[str, Any]]:
   28 + def verify_math_answer(response_text: str, ground_truth: str) -> Tuple[float, Dict[str, Any]]:
29 29   """Strictly compare the equality of response and groundtruth."""
30 30   # Parse the response
31 31   parsed_prediction = extract_answer(response_text)
@@ -234,7 +234,7 @@ def extract_answer(response_text: str) -> Optional[str]:
234 234   "inch",
235 235   ]
236 236
237     - unit_texts.extend([t + "s" for t in unit_texts])
    237 + unit_texts.extend([t + "s" for t in unit_texts if not t.endswith("s")])
238 238
239 239
240 240   def strip_string(input_str: Optional[str]) -> Optional[str]:
@@ -319,7 +319,7 @@ def fix_fracs(string):
319 319   else:
320 320   try:
321 321   assert len(substr) >= 2
322     - except Exception:
    322 + except AssertionError:
323 323   return string
324 324   a = substr[0]
325 325   b = substr[1]

0 commit comments

Comments (0)