@@ -92,7 +92,7 @@ async def run_async(self):
             experiences
         )  # NOTE: this metric works only if the agent calls the model once in each turn
 
-        reward_dict = self.calculate_reward(answer_and_sources)
+        reward_dict = await self.calculate_reward(answer_and_sources)
         reward = sum(reward_dict.values())
 
         for i, experience in enumerate(experiences):
@@ -107,7 +107,7 @@ async def run_async(self):
         )
         return experiences
 
-    def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
+    async def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
         """Ref: calculate_reward in https://github.com/OpenPipe/ART/blob/main/dev/art-e/art_e/rollout.py#L64"""
         try:
             answer = answer_and_sources.get("answer", None)
@@ -140,7 +140,7 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
 
         try:
             judge_model = self.auxiliary_models[0] if self.auxiliary_models else None
-            judge_response = judge_correctness(answer, self.query, judge_model)
+            judge_response = await judge_correctness(answer, self.query, judge_model)
             rubric.answer_correct = judge_response
 
         except Exception as e:
@@ -179,4 +179,4 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
             return result
 
         self.logger.error(f"Rubric {rubric} not handled properly")
-        raise ValueError("Rubric is not handled properly")
+        return {"accuracy": 0.0, "format": 0.0}
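For context, here is a minimal sketch of what the now-awaited `judge_correctness` might look like. The real helper is not part of this diff, and the `judge_model.chat` call below is a hypothetical stand-in for whatever async client the project actually uses:

```python
# Hypothetical sketch -- the real judge_correctness is not shown in this diff.
# Assumes judge_model exposes an async chat-completion-style call.
async def judge_correctness(answer: str, query: str, judge_model) -> bool:
    prompt = (
        "Query:\n{query}\n\nAnswer:\n{answer}\n\n"
        "Reply with exactly YES if the answer is correct, otherwise NO."
    ).format(query=query, answer=answer)
    # This await is what forces calculate_reward (and its caller) to be async.
    response = await judge_model.chat(prompt)
    return response.strip().upper().startswith("YES")
```

Once `judge_correctness` is a coroutine, the `await` has to propagate upward: `calculate_reward` becomes `async` and `run_async` awaits it, since a sync caller would otherwise receive an unawaited coroutine object instead of the reward dict.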