Skip to content

Commit 97042fd

Browse files
authored
Reusing extract_answer_loose in accuracy_reward (#7)
1 parent f6a0ca6 commit 97042fd

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,13 @@ results = await asyncio.gather(
182182
# Compute rewards
183183
per_category_rewards = defaultdict(list)
184184
for row, result in zip(test_ds, results, strict=True):
185+
# NOTE: you can also use `ether0.rewards.accuracy_reward`,
186+
# but we decided to go a bit "lower level" for this demo
185187
reward_info = RewardFunctionInfo.model_validate(row["solution"])
186188
yhat = extract_answer_loose(result[0].text)
187-
reward = EVAL_FUNCTIONS[reward_info.fxn_name](yhat=yhat, y=reward_info.answer_info)
189+
reward = EVAL_FUNCTIONS[reward_info.fxn_name](
190+
yhat=yhat, y=reward_info.answer_info, test=True
191+
)
188192
per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward)
189193

190194
for category, rewards in sorted(per_category_rewards.items()):

src/ether0/rewards.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from ether0.clients import fetch_forward_rxn, fetch_purchasable, fetch_solubility
2121
from ether0.data import is_reasonable_fp, is_reasonable_ring_system, mol_from_smiles
22-
from ether0.model_prompts import extract_thought_answer_strict
22+
from ether0.model_prompts import extract_answer_loose, extract_thought_answer_strict
2323
from ether0.models import RewardFunctionInfo, RewardReason
2424

2525
block = BlockLogs()
@@ -702,14 +702,11 @@ def accuracy_reward(
702702
reward_info = RewardFunctionInfo.model_validate(info)
703703
fxn_name, answer_info, problem_type = tuple(reward_info.model_dump().values())
704704
try:
705-
if test:
706-
answer: str | None = (
707-
content.split("<answer>")[1].split("</answer>")[0]
708-
if "<answer>" in content
709-
else content
710-
)
711-
else:
712-
answer = extract_thought_answer_strict(content, reasoning=reasoning)[1]
705+
answer: str | None = (
706+
extract_answer_loose(content)
707+
if test
708+
else extract_thought_answer_strict(content, reasoning=reasoning)[1]
709+
)
713710
if answer is not None:
714711
# During test time, see if full SMILES string was given as input
715712
if problem_type == "valid_mol_eval" and test:

0 commit comments

Comments
 (0)