import re
from typing import Dict, Any, List, Optional
from collections import Counter  # For potential use in advanced text matching


# Helper to normalize text, might be useful for content comparison
def _normalize_text(text: str) -> str:
    """Lowercase and remove punctuation and extra whitespace."""
    if not text:
        return ""
    text = text.lower()
    # Keep alphanumeric characters and spaces, remove everything else
    text = re.sub(r'[^a-z0-9\s]', '', text)
    text = ' '.join(text.split())  # Collapse extra whitespace
    return text
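
# Illustrative example (not part of the original module) of the intended behavior:
#   _normalize_text("  Search: <Red Shoes>!! ")  ->  "search red shoes"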

# --- Component 1: Environment Reward Summation ---
def _compute_env_reward_sum(trajectory: List[Dict], reward_scale: float = 1.0, reward_clip: Optional[float] = None) -> float:
    """
    Calculates the sum of rewards directly obtained from the environment steps.
    These are typically stored in the 'reward' field of the 'gpt' turns (as recorded from the env step output).
    """
    raw_env_rewards = []
    # Iterate through the trajectory to find rewards associated with agent actions
    for i, turn in enumerate(trajectory):
        if turn.get('from') == 'gpt':  # Agent's turn
            # Check if the reward for this action is stored on this turn
            if 'reward' in turn and isinstance(turn['reward'], (int, float)):
                raw_env_rewards.append(float(turn['reward']))
            # Alternatively, the reward may appear in the subsequent 'env' turn, but reading
            # both places would risk double-counting when it is already on the 'gpt' turn:
            # elif i + 1 < len(trajectory) and trajectory[i + 1].get('from') == 'env' and \
            #         'reward' in trajectory[i + 1] and isinstance(trajectory[i + 1]['reward'], (int, float)):
            #     raw_env_rewards.append(float(trajectory[i + 1]['reward']))

    sum_env_reward = sum(raw_env_rewards)

    scaled_reward = sum_env_reward * reward_scale
    if reward_clip is not None:
        scaled_reward = max(-reward_clip, min(reward_clip, scaled_reward))

    return scaled_reward
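
# Illustrative sketch (not from the original code): with a toy trajectory whose two
# agent turns carry step rewards of 0.5 and 1.0,
#   _compute_env_reward_sum([{'from': 'human', 'value': 'obs'},
#                            {'from': 'gpt', 'value': '<act>a</act>', 'reward': 0.5},
#                            {'from': 'gpt', 'value': '<act>b</act>', 'reward': 1.0}],
#                           reward_scale=1.0, reward_clip=1.0)
# sums the rewards to 1.5 and then clips the result to 1.0.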

# --- Component 2: Format Reward ---
def _compute_format_reward(
    full_agent_generation_text: str,
    max_reward: float,
    min_reward: float,
    check_all_tags: bool = True
) -> float:
    """
    Checks if the agent's output adheres to the specified format:
    <think>...</think> <memory>...</memory> <plan>...</plan> <think>...</think> <act>...</act>
    """
    if not full_agent_generation_text:
        return min_reward

    text_to_check = re.sub(r'\s+', ' ', full_agent_generation_text).strip()
    score = min_reward  # Default to min_reward

    if check_all_tags:
        # Pattern for the full tag sequence. It is permissive: any characters
        # (including newlines, via re.DOTALL) are allowed inside and between the tags.
        full_pattern = r"<think>.*?</think>.*?<memory>.*?</memory>.*?<plan>.*?</plan>.*?<think>.*?</think>.*?<act>.*?</act>"

        # Check for the presence of individual tags for partial credit
        has_think = bool(re.search(r"<think>.*?</think>", text_to_check, re.DOTALL))
        has_memory = bool(re.search(r"<memory>.*?</memory>", text_to_check, re.DOTALL))
        has_plan = bool(re.search(r"<plan>.*?</plan>", text_to_check, re.DOTALL))
        has_act = bool(re.search(r"<act>.*?</act>", text_to_check, re.DOTALL))
        num_think_tags = len(re.findall(r"<think>.*?</think>", text_to_check, re.DOTALL))

        if re.search(full_pattern, text_to_check, re.DOTALL):
            score = max_reward
        elif num_think_tags >= 1 and has_memory and has_plan and has_act:
            # All key components are present, just not in the exact full sequence
            score = (max_reward + min_reward) / 1.5  # Generous partial credit
        elif has_think and has_act:  # Minimal: at least one think and one act
            score = (max_reward + min_reward) / 2.0
        # else score remains min_reward

    else:
        # Simpler check: a <think> block followed by an <act> block, possibly separated
        # by whitespace. Typically used for the last action segment only.
        simple_pattern = r"<think>.*?</think>\s*<act>.*?</act>"
        if re.search(simple_pattern, text_to_check, re.DOTALL):
            score = max_reward
        # else score remains min_reward

    return score
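
# Illustrative sketch (not from the original code): with max_reward=1.0 and
# min_reward=-0.5, a generation such as
#   "<think>a</think><memory>b</memory><plan>c</plan><think>d</think><act>e</act>"
# matches the full pattern and scores 1.0, whereas "<think>a</think> <act>e</act>"
# only satisfies the minimal think/act check and scores (1.0 + -0.5) / 2.0 = 0.25.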

# --- Component 3: Length Reward ---
def _compute_length_reward(
    text_content: str,
    max_reward: float,
    min_reward: float,
    target_len_words: int,
    penalty_if_missing: bool = True,
    too_short_penalty_factor: float = 0.5,
    too_long_penalty_factor: float = 0.5,
    tolerance_factor: float = 0.2  # e.g., +/- 20% of target_len_words
) -> float:
    """
    Rewards based on the length of the provided text content (in words).
    """
    if not text_content:
        return min_reward if penalty_if_missing else (min_reward + max_reward) / 2

    num_words = len(text_content.split())

    if num_words == 0 and penalty_if_missing:
        return min_reward

    if target_len_words <= 0:  # Avoid division by zero if target length is invalid
        return (min_reward + max_reward) / 2

    lower_bound = target_len_words * (1 - tolerance_factor)
    upper_bound = target_len_words * (1 + tolerance_factor)

    if lower_bound <= num_words <= upper_bound:
        return max_reward
    elif num_words < lower_bound:
        # Too short: penalize in proportion to how far below the lower bound the text falls.
        # The maximum penalty is (max_reward - min_reward) * too_short_penalty_factor
        # (approached as num_words goes to 0) and shrinks linearly to 0 as num_words
        # approaches lower_bound.
        shortage_ratio = num_words / lower_bound
        penalty = (max_reward - min_reward) * too_short_penalty_factor * (1.0 - shortage_ratio)
        return max(min_reward, max_reward - penalty)
    else:  # num_words > upper_bound
        # Too long: same idea, scaled by how far (proportionally) the text exceeds the
        # upper bound, with the penalty effect capped at excess_ratio = 1.0.
        excess_ratio = (num_words - upper_bound) / upper_bound
        penalty = (max_reward - min_reward) * too_long_penalty_factor * min(1.0, excess_ratio)
        return max(min_reward, max_reward - penalty)
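
# Illustrative sketch (not from the original code): with max_reward=0.5, min_reward=-0.25,
# target_len_words=50 and tolerance_factor=0.2, any text of 40-60 words earns 0.5; a
# 20-word text has shortage_ratio = 20/40 = 0.5, so the penalty is
# (0.5 - (-0.25)) * 0.5 * 0.5 = 0.1875 and the score is 0.5 - 0.1875 = 0.3125.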


# --- Component 4: Ground Truth Trajectory Similarity ---
def _extract_actions_from_trajectory(trajectory: List[Dict]) -> List[str]:
    """Extracts content from <act>...</act> tags from 'gpt' turns."""
    actions = []
    act_pattern = r"<act>(.*?)</act>"
    for turn in trajectory:
        if turn.get('from') == 'gpt':
            value = turn.get('value', '')
            # Find all non-overlapping matches in the string
            matches = re.findall(act_pattern, value, re.DOTALL)
            actions.extend([match.strip() for match in matches])
    return actions


def _compute_gt_traj_similarity_reward(
    generated_actions: List[str],
    ground_truth_actions: List[str],
    max_reward: float,
    min_reward: float
) -> float:
    """
    Compares a list of extracted agent actions with a list of ground truth actions.
    Uses a simple precision-like score based on sequential matching.
    """
    if not ground_truth_actions:
        # With no GT actions the comparison is ill-defined; default to a neutral value.
        return (max_reward + min_reward) / 2

    if not generated_actions:  # Agent took no valid actions
        return min_reward

    len_gt = len(ground_truth_actions)

    matches = 0
    gt_idx = 0
    # Try to match generated actions against GT actions in order
    for gen_act in generated_actions:
        if gt_idx < len_gt and _normalize_text(gen_act) == _normalize_text(ground_truth_actions[gt_idx]):
            matches += 1
            gt_idx += 1  # Move to the next GT action only if the current one matched

    # Similarity is the ratio of matched GT actions to total GT actions
    similarity = matches / len_gt if len_gt > 0 else 0.0

    score = min_reward + (max_reward - min_reward) * similarity
    return score
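
# Illustrative sketch (not from the original code): a 'gpt' turn whose value is
#   "<think>pick the red pair</think><act>search[red shoes]</act>"
# yields the extracted action ["search[red shoes]"]. If the ground truth is
# ["search[red shoes]", "click[buy now]"], one of the two GT actions matches in order,
# so similarity = 0.5 and the score is min_reward + (max_reward - min_reward) * 0.5.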


def compute_score(
    env_name: str,
    **kwargs
) -> float:
    """
    Computes a composite score for an AgentGym environment trajectory.

    Args:
        env_name: The name of the AgentGym environment.
        **kwargs: Expected to contain:
            - 'trajectory' (List[Dict]): The agent's interaction log.
              Each dict: {'from': 'gpt'/'env', 'value': str, 'reward': float (from env step), ...}
            - 'reward_model_info' (Dict, optional): Contains parameters and ground truth, e.g.:
                - 'ground_truth_actions': List[str] (for GT trajectory comparison)
                - 'env_reward_weight', 'env_reward_scale', 'env_reward_clip'
                - 'format_reward_weight', 'format_max_r', 'format_min_r', 'format_check_all_tags'
                - 'length_reward_weight', 'length_max_r', 'length_min_r',
                  'length_target_words', 'length_penalty_if_missing',
                  'length_too_short_penalty_factor', 'length_too_long_penalty_factor',
                  'length_tolerance_factor'
                - 'gt_sim_reward_weight', 'gt_sim_max_r', 'gt_sim_min_r'
            - 'step' (int, optional): Current training step (for potential future scheduling).

    Returns:
        The calculated composite score as a float.
    """
    trajectory = kwargs.get('trajectory')
    reward_model_info = kwargs.get('reward_model_info') or {}
    current_step = kwargs.get('step', 0)  # Reserved for potential future reward scheduling

    if not trajectory:
        print(f"Warning: 'trajectory' missing in kwargs for env '{env_name}'. Cannot compute score. Returning 0.0.")
        return 0.0

    # --- Define default weights and parameters ---
    env_reward_weight = float(reward_model_info.get('env_reward_weight', 0.25))
    env_reward_scale = float(reward_model_info.get('env_reward_scale', 1.0))
    # Clip the summed env reward; if None, no clipping is applied
    env_reward_clip_val = reward_model_info.get('env_reward_clip', 5.0)
    env_reward_clip = float(env_reward_clip_val) if env_reward_clip_val is not None else None

    format_reward_weight = float(reward_model_info.get('format_reward_weight', 0.25))
    format_max_r = float(reward_model_info.get('format_max_r', 1.0))
    format_min_r = float(reward_model_info.get('format_min_r', -0.5))  # Allow a penalty for bad format
    format_check_all_tags = bool(reward_model_info.get('format_check_all_tags', True))

    length_reward_weight = float(reward_model_info.get('length_reward_weight', 0.15))
    length_max_r = float(reward_model_info.get('length_max_r', 0.5))  # Max reward for good length may be less than 1
    length_min_r = float(reward_model_info.get('length_min_r', -0.25))
    length_target_words = int(reward_model_info.get('length_target_words', 50))
    length_penalty_if_missing = bool(reward_model_info.get('length_penalty_if_missing', True))
    length_too_short_penalty_factor = float(reward_model_info.get('length_too_short_penalty_factor', 0.5))
    length_too_long_penalty_factor = float(reward_model_info.get('length_too_long_penalty_factor', 0.5))
    length_tolerance_factor = float(reward_model_info.get('length_tolerance_factor', 0.3))

    gt_sim_reward_weight = float(reward_model_info.get('gt_sim_reward_weight', 0.35))
    gt_sim_max_r = float(reward_model_info.get('gt_sim_max_r', 1.0))
    gt_sim_min_r = float(reward_model_info.get('gt_sim_min_r', 0.0))
    ground_truth_actions = reward_model_info.get('ground_truth_actions', [])

    # --- Component 1: Environment Reward Summation ---
    env_reward_score_component = _compute_env_reward_sum(trajectory, env_reward_scale, env_reward_clip)

    # --- Consolidate Agent Text for Format/Length ---
    agent_generations_text = ""
    if isinstance(trajectory, list):
        agent_generations_text = "\n".join(
            [turn['value'] for turn in trajectory if turn.get('from') == 'gpt' and isinstance(turn.get('value'), str)]
        )
    else:
        print(f"Warning: Unexpected trajectory format: {type(trajectory)}. Format/length/GT rewards might be inaccurate.")

    # --- Component 2: Format Reward ---
    format_score_component = _compute_format_reward(
        agent_generations_text, format_max_r, format_min_r, format_check_all_tags
    )

    # --- Component 3: Length Reward (e.g., for combined <think> content) ---
    all_think_content = ""
    if agent_generations_text:
        think_pattern = r"<think>(.*?)</think>"
        for match in re.finditer(think_pattern, agent_generations_text, re.DOTALL):
            all_think_content += match.group(1).strip() + " "
        all_think_content = all_think_content.strip()

    length_score_component = _compute_length_reward(
        all_think_content, length_max_r, length_min_r, length_target_words,
        length_penalty_if_missing, length_too_short_penalty_factor, length_too_long_penalty_factor,
        length_tolerance_factor
    )

    # --- Component 4: Ground Truth Trajectory Similarity ---
    generated_actions = []
    if isinstance(trajectory, list):
        generated_actions = _extract_actions_from_trajectory(trajectory)

    gt_sim_score_component = _compute_gt_traj_similarity_reward(
        generated_actions, ground_truth_actions, gt_sim_max_r, gt_sim_min_r
    )

    # --- Total Score ---
    total_score = (
        env_reward_weight * env_reward_score_component +
        format_reward_weight * format_score_component +
        length_reward_weight * length_score_component +
        gt_sim_reward_weight * gt_sim_score_component
    )

    # Optional overall clipping/scaling, e.g., to a standard range like [-1, 1] or [0, 1].
    # If the weights sum to 1 this is not strictly needed unless individual components can be large.
    # total_score = max(-1.0, min(1.0, total_score))  # Example clip

    # For debugging, print the individual component scores:
    # print(f"[compute_score] Env '{env_name}': "
    #       f"EnvR_raw={env_reward_score_component:.2f} (w={env_reward_weight:.2f}), "
    #       f"FmtR_raw={format_score_component:.2f} (w={format_reward_weight:.2f}), "
    #       f"LenR_raw={length_score_component:.2f} (w={length_reward_weight:.2f}), "
    #       f"GtSimR_raw={gt_sim_score_component:.2f} (w={gt_sim_reward_weight:.2f}) "
    #       f"--- TOTAL_raw={total_score:.2f}")

    return total_score
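

# Illustrative smoke test (not part of the original module). The toy trajectory and
# reward_model_info below are made-up examples of the expected input structure; they
# are only meant to show how compute_score can be invoked.
if __name__ == "__main__":
    _toy_trajectory = [
        {'from': 'human', 'value': 'Instruction: buy a pair of red shoes.'},
        {'from': 'gpt',
         'value': ("<think>I should search first.</think>"
                   "<memory>User wants red shoes.</memory>"
                   "<plan>Search, then click buy.</plan>"
                   "<think>Issue the search action now.</think>"
                   "<act>search[red shoes]</act>"),
         'reward': 0.0},
        {'from': 'env', 'value': 'Results: [1] red shoes ...'},
        {'from': 'gpt',
         'value': "<think>The first result matches.</think><act>click[buy now]</act>",
         'reward': 1.0},
    ]
    _toy_reward_model_info = {
        'ground_truth_actions': ['search[red shoes]', 'click[buy now]'],
        'length_target_words': 20,
    }
    print(compute_score('webshop', trajectory=_toy_trajectory,
                        reward_model_info=_toy_reward_model_info))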