File tree Expand file tree Collapse file tree 2 files changed +11
-10
lines changed
Expand file tree Collapse file tree 2 files changed +11
-10
lines changed Original file line number Diff line number Diff line change @@ -182,9 +182,13 @@ results = await asyncio.gather(
182182# Compute rewards
183183per_category_rewards = defaultdict(list )
184184for row, result in zip (test_ds, results, strict = True ):
185+ # NOTE : you can also use `ether0.rewards.accuracy_reward`,
186+ # but we decided to go a bit "lower level" for this demo
185187 reward_info = RewardFunctionInfo.model_validate(row[" solution" ])
186188 yhat = extract_answer_loose(result[0 ].text)
187- reward = EVAL_FUNCTIONS [reward_info.fxn_name](yhat = yhat, y = reward_info.answer_info)
189+ reward = EVAL_FUNCTIONS [reward_info.fxn_name](
190+ yhat = yhat, y = reward_info.answer_info, test = True
191+ )
188192 per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward)
189193
190194for category, rewards in sorted (per_category_rewards.items()):
Original file line number Diff line number Diff line change 1919
2020from ether0 .clients import fetch_forward_rxn , fetch_purchasable , fetch_solubility
2121from ether0 .data import is_reasonable_fp , is_reasonable_ring_system , mol_from_smiles
22- from ether0 .model_prompts import extract_thought_answer_strict
22+ from ether0 .model_prompts import extract_answer_loose , extract_thought_answer_strict
2323from ether0 .models import RewardFunctionInfo , RewardReason
2424
2525block = BlockLogs ()
@@ -702,14 +702,11 @@ def accuracy_reward(
702702 reward_info = RewardFunctionInfo .model_validate (info )
703703 fxn_name , answer_info , problem_type = tuple (reward_info .model_dump ().values ())
704704 try :
705- if test :
706- answer : str | None = (
707- content .split ("<answer>" )[1 ].split ("</answer>" )[0 ]
708- if "<answer>" in content
709- else content
710- )
711- else :
712- answer = extract_thought_answer_strict (content , reasoning = reasoning )[1 ]
705+ answer : str | None = (
706+ extract_answer_loose (content )
707+ if test
708+ else extract_thought_answer_strict (content , reasoning = reasoning )[1 ]
709+ )
713710 if answer is not None :
714711 # During test time, see if full SMILES string was given as input
715712 if problem_type == "valid_mol_eval" and test :
You can’t perform that action at this time.
0 commit comments