@@ -23,6 +23,7 @@ def __init__(self, envs, projection_f, config):
         self.ground_truths = []
         self.step_counts = []
         self.task_completed = []
+        self.task_success = []

     def reset(self):
         """Reset environment and get new tasks"""
@@ -36,6 +37,7 @@ def reset(self):
         batch_size = len(self.current_tasks)
         self.step_counts = [0] * batch_size
         self.task_completed = [False] * batch_size
+        self.task_success = [False] * batch_size

         # Initialize memory
         self.memory.reset(batch_size=batch_size)
@@ -59,7 +61,7 @@ def step(self, text_actions: List[str]):
         for i, (action, valid) in enumerate(zip(actions, valids)):
             if self.task_completed[i]:
                 observations.append("Task completed.")
-                infos.append({'is_action_valid': True, 'won': True})
+                infos.append({'is_action_valid': True, 'won': self.task_success[i]})
                 continue

             self.step_counts[i] += 1
@@ -70,14 +72,19 @@ def step(self, text_actions: List[str]):

             # Check completion
             if self._is_completion_action(action):
+                is_correct = self._evaluate_answer(action, i)
+                self.task_success[i] = is_correct
                 self.task_completed[i] = True
                 dones[i] = True
+                obs_feedback = "\n\nEvaluation: final answer matches the ground truth." if is_correct else "\n\nEvaluation: final answer does not match the ground truth."
+                observations[-1] = obs + obs_feedback
             elif self.step_counts[i] >= self.config.env.max_steps:
                 obs += "\n\nMaximum steps reached. Please provide your final answer in <answer></answer> tags."
                 dones[i] = True
+                observations[-1] = obs

             info['is_action_valid'] = to_numpy(valid)
-            info['won'] = self.task_completed[i]
+            info['won'] = self.task_success[i]
             info['step_count'] = self.step_counts[i]
             infos.append(info)

@@ -125,28 +132,78 @@ def _is_completion_action(self, action: str) -> bool:
         """Check if action indicates task completion"""
         return action.startswith("FINAL_ANSWER:") or "<answer>" in action

+    def _evaluate_answer(self, action: str, batch_idx: int) -> bool:
+        """Compare model answer with ground truth"""
+        predicted = self._extract_answer_text(action)
+        ground_truth = self.ground_truths[batch_idx]
+        return self._normalize_answer(predicted) == self._normalize_answer(ground_truth)
+
+    @staticmethod
+    def _extract_answer_text(action: str) -> str:
+        """Extract answer text from action string"""
+        if action.startswith("FINAL_ANSWER:"):
+            return action.split("FINAL_ANSWER:", 1)[1].strip()
+
+        match = re.search(r"<answer>(.*?)</answer>", action, re.DOTALL)
+        if match:
+            return match.group(1).strip()
+        return action.strip()
+
+    @staticmethod
+    def _normalize_answer(text: str) -> str:
+        """Normalize answer string for comparison"""
+        normalized = re.sub(r"\s+", " ", text).strip().lower()
+        normalized = normalized.strip(".,!?:;\" ")
+        return normalized
+
     def build_text_obs(self, observations: List[str] = None, init: bool = False) -> List[str]:
         """Build text observations for agent"""
         batch_size = len(self.current_tasks)
         postprocess_text_obs = []
-
+        max_steps = getattr(self.config.env, "max_steps", None)
+        history_length_cfg = getattr(self.config.env, "history_length", 0)
+
+        if not init and history_length_cfg > 0:
+            memory_contexts, valid_lens = self.memory.fetch(
+                history_length_cfg,
+                obs_key="text_obs",
+                action_key="action",
+            )
+        else:
+            memory_contexts = [""] * batch_size
+            valid_lens = [0] * batch_size
+
         for i in range(batch_size):
-            if init or self.config.env.history_length <= 0:
+            current_obs = observations[i] if observations else "Continue with your task."
+            should_use_last_step = (
+                not init
+                and not self.task_completed[i]
+                and max_steps is not None
+                and self.step_counts[i] >= max_steps - 1
+            )
+
+            if init:
                 obs = TOOL_USE_TEMPLATE_NO_HIS.format(
                     task_description=self.current_tasks[i],
                     available_tools=self.tool_metadata,
                     current_observation="Start working on the task."
                 )
-            else:
-                # Get history
-                memory_contexts, valid_lens = self.memory.fetch(
-                    self.config.env.history_length,
-                    obs_key="text_obs",
-                    action_key="action"
+            elif should_use_last_step:
+                obs = TOOL_USE_TEMPLATE_LAST_STEP.format(
+                    task_description=self.current_tasks[i],
+                    step_count=self.step_counts[i],
+                    history_length=valid_lens[i],
+                    action_history=memory_contexts[i],
+                    current_step=self.step_counts[i] + 1,
+                    current_observation=current_obs,
                 )
-
-                current_obs = observations[i] if observations else "Continue with your task."
-
+            elif history_length_cfg <= 0:
+                obs = TOOL_USE_TEMPLATE_NO_HIS.format(
+                    task_description=self.current_tasks[i],
+                    available_tools=self.tool_metadata,
+                    current_observation=current_obs,
+                )
+            else:
                 obs = TOOL_USE_TEMPLATE.format(
                     task_description=self.current_tasks[i],
                     step_count=self.step_counts[i],
@@ -156,7 +213,7 @@ def build_text_obs(self, observations: List[str] = None, init: bool = False) ->
                     current_observation=current_obs,
                     available_tools=self.tool_metadata
                 )
-
+
             postprocess_text_obs.append(obs)

         return postprocess_text_obs
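
For reference, here is a minimal standalone sketch of how the answer matching added above behaves. The two helpers mirror `_extract_answer_text` and `_normalize_answer` from the diff; the sample strings are illustrative only, not taken from the repository's tasks.

```python
import re


def extract_answer_text(action: str) -> str:
    # Mirrors _extract_answer_text: prefer an explicit FINAL_ANSWER: prefix,
    # then <answer>...</answer> tags, otherwise fall back to the raw action.
    if action.startswith("FINAL_ANSWER:"):
        return action.split("FINAL_ANSWER:", 1)[1].strip()
    match = re.search(r"<answer>(.*?)</answer>", action, re.DOTALL)
    if match:
        return match.group(1).strip()
    return action.strip()


def normalize_answer(text: str) -> str:
    # Mirrors _normalize_answer: collapse whitespace, lowercase,
    # and strip surrounding punctuation and quotes.
    normalized = re.sub(r"\s+", " ", text).strip().lower()
    return normalized.strip(".,!?:;\" ")


# Illustrative inputs (assumed examples, not real task data):
print(normalize_answer(extract_answer_text("FINAL_ANSWER: Paris.")))     # paris
print(normalize_answer(extract_answer_text("<answer>  42  </answer>")))  # 42
print(normalize_answer("PARIS") == normalize_answer("  paris. "))        # True
```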
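The reworked `build_text_obs` now picks one of three prompt templates per task. Below is a small sketch of that branch order as a pure function, assuming the same last-step condition as in the diff; the string return values stand in for the actual templates, and the numbers in the checks are illustrative.

```python
from typing import Optional


def choose_template(init: bool, task_completed: bool, step_count: int,
                    max_steps: Optional[int], history_length: int) -> str:
    # Branch order follows the diff: init first, then the last allowed step,
    # then history disabled, otherwise the full history template.
    if init:
        return "TOOL_USE_TEMPLATE_NO_HIS"
    if not task_completed and max_steps is not None and step_count >= max_steps - 1:
        return "TOOL_USE_TEMPLATE_LAST_STEP"
    if history_length <= 0:
        return "TOOL_USE_TEMPLATE_NO_HIS"
    return "TOOL_USE_TEMPLATE"


# With, say, max_steps=5 and history_length=2:
assert choose_template(True, False, 0, 5, 2) == "TOOL_USE_TEMPLATE_NO_HIS"
assert choose_template(False, False, 2, 5, 2) == "TOOL_USE_TEMPLATE"
assert choose_template(False, False, 4, 5, 2) == "TOOL_USE_TEMPLATE_LAST_STEP"
```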