
Commit 2c79359

add reward_fn (#65)
1 parent 1ff6375 commit 2c79359

File tree

2 files changed (+277, -69 lines)


verl/trainer/ppo/ray_trainer.py

Lines changed: 0 additions & 41 deletions
@@ -42,8 +42,6 @@
 import re
 from openmanus_rl.llm_agent.openmanus import OpenManusAgent, AgentConfig
 from verl.utils.reward_score import SUPPORTED_REWARD_SCORE_FNS
-from verl.utils.reward_score.agentgym import compute_score as agentgym_compute_score
-from verl.utils.reward_score.reward_components import RewardComposer, GoalReward, LengthPenalty, FormatReward
 from verl.utils.tracking import Tracking

 import ray
@@ -518,52 +516,13 @@ def __init__(

         self._create_dataloader()
         self._init_logger()
-        self._init_reward_composer()

     def _init_logger(self):
         self.logger = Tracking(project_name=self.config.trainer.project_name,
                                experiment_name=self.config.trainer.experiment_name,
                                default_backend=self.config.trainer.logger,
                                config=OmegaConf.to_container(self.config, resolve=True))

-    def _init_reward_composer(self):
-        """Initializes the RewardComposer based on the configuration."""
-        components = []
-        cfg = self.reward_component_config
-        print(f"[Trainer._init_reward_composer] Initializing with config: {cfg}")
-
-        # --- Build Reward Components List ---
-        # Example: Dynamically add components based on config
-        if cfg.get('goal_reward', {}).get('enabled', True):
-            components.append(GoalReward(weight=cfg['goal_reward'].get('weight', 1.0)))
-            print(" - Added GoalReward")
-
-        if cfg.get('length_penalty', {}).get('enabled', False):
-            lp_cfg = cfg['length_penalty']
-            components.append(LengthPenalty(
-                weight=lp_cfg.get('weight', -0.01),
-                max_length=lp_cfg.get('max_length', 500),
-                min_length=lp_cfg.get('min_length', 10),
-                penalty_type=lp_cfg.get('penalty_type', "linear")
-            ))
-            print(" - Added LengthPenalty")
-
-        if cfg.get('format_reward', {}).get('enabled', False):
-            fmt_cfg = cfg['format_reward']
-            # Get patterns specific to the current env or use default
-            patterns = fmt_cfg.get('patterns_by_env', {}).get(
-                self.config.data.env_name,  # Assumes env_name is available in self.config.data
-                fmt_cfg.get('patterns_by_env', {}).get('default', [])
-            )
-            components.append(FormatReward(
-                weight=fmt_cfg.get('weight', 0.2),
-                required_patterns=patterns
-            ))
-            print(f" - Added FormatReward with patterns: {patterns}")
-
-        self.reward_composer = RewardComposer(components=components)
-        print(f"[Trainer._init_reward_composer] Composer initialized with {len(components)} components.")
-
     def _create_dataloader(self):
         from torch.utils.data import DataLoader
         # TODO: we have to make sure the batch size is divisible by the dp size
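
For reference, the removed _init_reward_composer read its component settings from self.reward_component_config. A hypothetical sketch of that config's shape, with key names taken from the deleted code above and purely illustrative values:

# Hypothetical reward_component_config consumed by the removed _init_reward_composer
# (key names mirror the deleted code; the values below are illustrative, not from the repo).
reward_component_config = {
    "goal_reward": {"enabled": True, "weight": 1.0},
    "length_penalty": {
        "enabled": False,
        "weight": -0.01,
        "max_length": 500,
        "min_length": 10,
        "penalty_type": "linear",
    },
    "format_reward": {
        "enabled": False,
        "weight": 0.2,
        # Per-environment regex patterns; "default" is the fallback bucket.
        "patterns_by_env": {"default": [r"<think>.*?</think>", r"<act>.*?</act>"]},
    },
}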
Lines changed: 277 additions & 28 deletions
@@ -1,57 +1,306 @@
 import re
 from typing import Dict, Any, List, Optional
+from collections import Counter  # For potential use in advanced text matching

+# Helper to normalize text, might be useful for content comparison
 def _normalize_text(text: str) -> str:
     """Lowercase and remove punctuation and extra whitespace."""
     if not text:
         return ""
     text = text.lower()
     # Keep spaces and alphanumeric, remove others
-    text = re.sub(r'[^a-z0-9\s]', '', text)
+    text = re.sub(r'[^a-z0-9\s]', '', text)  # Keep original regex
     text = ' '.join(text.split())  # Remove extra whitespace
     return text

+# --- Component 1: Environment Reward Summation ---
+def _compute_env_reward_sum(trajectory: List[Dict], reward_scale: float = 1.0, reward_clip: Optional[float] = None) -> float:
+    """
+    Calculates the sum of rewards directly obtained from the environment steps.
+    These are typically stored in the 'reward' field of turns from the 'env' or associated with 'gpt' turns.
+    """
+    raw_env_rewards = []
+    # Iterate through the trajectory to find rewards associated with agent actions or env feedback
+    for i, turn in enumerate(trajectory):
+        if turn.get('from') == 'gpt':  # Agent's turn
+            # Check if the reward for this action is stored in this turn
+            if 'reward' in turn and isinstance(turn['reward'], (int, float)):
+                raw_env_rewards.append(float(turn['reward']))
+            # Or if it's in the subsequent 'env' turn's info (less common for this arg structure)
+            # This part might be double-counting if 'reward' is already on 'gpt' turn based on env step output.
+            # elif i + 1 < len(trajectory) and trajectory[i+1].get('from') == 'env' and \
+            #         'reward' in trajectory[i+1] and isinstance(trajectory[i+1]['reward'], (int, float)):
+            #     raw_env_rewards.append(float(trajectory[i+1]['reward']))
+
+    sum_env_reward = sum(raw_env_rewards)
+
+    scaled_reward = sum_env_reward * reward_scale
+    if reward_clip is not None:
+        scaled_reward = max(-reward_clip, min(reward_clip, scaled_reward))
+
+    return scaled_reward
+
+# --- Component 2: Format Reward ---
+def _compute_format_reward(
+    full_agent_generation_text: str,
+    max_reward: float,
+    min_reward: float,
+    check_all_tags: bool = True
+) -> float:
+    """
+    Checks if the agent's output adheres to the specified format.
+    Format: <think> ...<memory>...</memory> ...<plan>...</plan>...<think> <act>...</act>
+    """
+    if not full_agent_generation_text:
+        return min_reward
+
+    text_to_check = re.sub(r'\s+', ' ', full_agent_generation_text).strip()
+    score = min_reward  # Default to min_reward
+
+    if check_all_tags:
+        # Pattern for the full sequence: <think>...<memory>...</memory>...<plan>...</plan>...<think>...</think>...<act>...</act>
+        # This regex is complex and greedy. It tries to find one instance of this structure.
+        # It allows any characters (including newlines due to re.DOTALL) within the tags and between them.
+        full_pattern = r"<think>.*?</think>.*?<memory>.*?</memory>.*?<plan>.*?</plan>.*?<think>.*?</think>.*?<act>.*?</act>"
+
+        # Check for presence of individual tags for partial credit
+        has_think = bool(re.search(r"<think>.*?</think>", text_to_check, re.DOTALL))
+        has_memory = bool(re.search(r"<memory>.*?</memory>", text_to_check, re.DOTALL))
+        has_plan = bool(re.search(r"<plan>.*?</plan>", text_to_check, re.DOTALL))
+        has_act = bool(re.search(r"<act>.*?</act>", text_to_check, re.DOTALL))
+        num_think_tags = len(re.findall(r"<think>.*?</think>", text_to_check, re.DOTALL))
+
+        if re.search(full_pattern, text_to_check, re.DOTALL):
+            score = max_reward
+        elif num_think_tags >= 1 and has_memory and has_plan and has_act:
+            # All key components present, but maybe not in the perfect full sequence or with extra stuff
+            score = (max_reward + min_reward) / 1.5  # Generous partial credit
+        elif has_think and has_act:  # Minimal: at least one think and one act
+            score = (max_reward + min_reward) / 2.0
+        # else score remains min_reward
+
+    else:  # Simpler check for just a final <think>...<act> sequence
+        # Looks for a think block followed by an act block, possibly with whitespace.
+        # This is usually for the last action segment.
+        simple_pattern = r"<think>.*?</think>\s*<act>.*?</act>"
+        if re.search(simple_pattern, text_to_check, re.DOTALL):
+            score = max_reward
+        # else score remains min_reward
+
+    return score
+
+# --- Component 3: Length Reward ---
+def _compute_length_reward(
+    text_content: str,
+    max_reward: float,
+    min_reward: float,
+    target_len_words: int,
+    penalty_if_missing: bool = True,
+    too_short_penalty_factor: float = 0.5,
+    too_long_penalty_factor: float = 0.5,
+    tolerance_factor: float = 0.2  # e.g., +/- 20% of target_len_words
+) -> float:
+    """
+    Rewards based on the length of the provided text content (in words).
+    """
+    if not text_content:
+        return min_reward if penalty_if_missing else (min_reward + max_reward) / 2
+
+    num_words = len(text_content.split())
+
+    if num_words == 0 and penalty_if_missing:
+        return min_reward
+
+    if target_len_words <= 0:  # Avoid division by zero if target length is invalid
+        return (min_reward + max_reward) / 2
+
+    lower_bound = target_len_words * (1 - tolerance_factor)
+    upper_bound = target_len_words * (1 + tolerance_factor)
+
+    if lower_bound <= num_words <= upper_bound:
+        return max_reward
+    elif num_words < lower_bound:
+        shortage_ratio = num_words / lower_bound
+        # Reward decreases from max_reward as it gets shorter.
+        # Example: if num_words is 0, score is min_reward. If num_words is just below lower_bound, score is slightly less than max_reward.
+        # This formula gives a linear ramp from min_reward to a point just below max_reward.
+        # A simpler approach: score = max_reward - ((lower_bound - num_words) / lower_bound) * (max_reward - min_reward) * too_short_penalty_factor
+        # Here: reward based on proximity to target, scaled by penalty factor for being too short.
+        # Max penalty = (max_reward - min_reward) * too_short_penalty_factor
+        # Actual penalty = Max penalty * (1 - shortage_ratio)
+        penalty = (max_reward - min_reward) * too_short_penalty_factor * (1.0 - shortage_ratio)
+        return max(min_reward, max_reward - penalty)
+
+    else:  # num_words > upper_bound
+        # Penalize for being too long, similar logic
+        excess_ratio = (num_words - upper_bound) / upper_bound  # How much, percentage-wise, it is over
+        penalty = (max_reward - min_reward) * too_long_penalty_factor * min(1.0, excess_ratio)  # Cap penalty effect
+        return max(min_reward, max_reward - penalty)
+
+
+# --- Component 4: Ground Truth Trajectory Similarity ---
+def _extract_actions_from_trajectory(trajectory: List[Dict]) -> List[str]:
+    """Extracts content from <act>...</act> tags from 'gpt' turns."""
+    actions = []
+    act_pattern = r"<act>(.*?)</act>"
+    for turn in trajectory:
+        if turn.get('from') == 'gpt':
+            value = turn.get('value', '')
+            # Find all non-overlapping matches in the string
+            matches = re.findall(act_pattern, value, re.DOTALL)
+            actions.extend([match.strip() for match in matches])
+    return actions
+
+def _compute_gt_traj_similarity_reward(
+    generated_actions: List[str],
+    ground_truth_actions: List[str],
+    max_reward: float,
+    min_reward: float
+) -> float:
+    """
+    Compares a list of extracted agent actions with a list of ground truth actions.
+    Uses a simple precision-like score based on sequential matching.
+    """
+    if not ground_truth_actions:
+        # If no GT actions, it's hard to score. Neutral or max? Let's go neutral.
+        return (max_reward + min_reward) / 2
+
+    if not generated_actions:  # Agent took no valid actions
+        return min_reward
+
+    len_gt = len(ground_truth_actions)
+
+    matches = 0
+    gt_idx = 0
+    # Try to match generated actions against GT actions in order
+    for gen_act in generated_actions:
+        if gt_idx < len_gt and _normalize_text(gen_act) == _normalize_text(ground_truth_actions[gt_idx]):
+            matches += 1
+            gt_idx += 1  # Move to next GT action only if current one matched
+
+    # Similarity is the ratio of matched GT actions to total GT actions
+    similarity = matches / len_gt if len_gt > 0 else 0.0
+
+    score = min_reward + (max_reward - min_reward) * similarity
+    return score
+

 def compute_score(
-    env_name: str,
+    env_name: str,
     **kwargs
 ) -> float:
     """
-    Computes a score for an AgentGym environment based on rollout results passed via kwargs.
-
-    It expects the full interaction trajectory and reward model info.
+    Computes a composite score for an AgentGym environment trajectory.

     Args:
         env_name: The name of the AgentGym environment.
-        **kwargs: Must contain 'trajectory' (List[Dict]) and 'reward_model_info' (Dict).
+        **kwargs: Expected to contain:
+            - 'trajectory' (List[Dict]): The agent's interaction log.
+              Each dict: {'from': 'gpt'/'env', 'value': str, 'reward': float (from env step), ...}
+            - 'reward_model_info' (Dict, optional): Contains parameters and ground truth. E.g.:
+                - 'ground_truth_actions': List[str] (for GT trajectory comparison)
+                - 'env_reward_weight', 'env_reward_scale', 'env_reward_clip'
+                - 'format_reward_weight', 'format_max_r', 'format_min_r', 'format_check_all_tags'
+                - 'length_reward_weight', 'length_max_r', 'length_min_r',
+                  'length_target_words', 'length_penalty_if_missing',
+                  'length_too_short_penalty_factor', 'length_too_long_penalty_factor', 'length_tolerance_factor'
+                - 'gt_sim_reward_weight', 'gt_sim_max_r', 'gt_sim_min_r'
+            - 'step' (int, optional): Current training step (for potential future scheduling).

     Returns:
-        The calculated score as a float (typically 0.0 or 1.0).
+        The calculated composite score as a float.
     """
     trajectory = kwargs.get('trajectory')
-    reward_model_info = kwargs.get('reward_model_info')
-    env_name_lower = env_name.lower()
-    score = 0.0
+    reward_model_info = kwargs.get('reward_model_info') if kwargs.get('reward_model_info') is not None else {}
+    current_step = kwargs.get('step', 0)

-    if not trajectory or not reward_model_info:
-        print(f"Warning: 'trajectory' or 'reward_model_info' missing in kwargs for env '{env_name}'. Cannot compute score.")
+    if not trajectory:
+        print(f"Warning: 'trajectory' missing in kwargs for env '{env_name}'. Cannot compute score. Returning 0.0.")
         return 0.0

-    style = reward_model_info.get("style")
-
-    try:
-        # --- WebShop Specific Logic ---
-        if env_name_lower in ["webshop", "webarena", "maze", "wordle", "alfworld", "sciworld", "babyai", "textcraft", "weather", "movie", "academia", "todo", "sheet", "sqlgym"]:
-            print(f"Warning: Trajectory-based scoring logic not yet implemented for env '{env_name}'. Returning 0.")
-            # Implement specific scoring functions for these envs based on their trajectory structure and success criteria
-            score = 0.0
-
-        else:
-            print(f"Warning: Unknown AgentGym environment '{env_name}' for reward scoring. Returning 0.")
+    # --- Define default weights and parameters ---
+    env_reward_weight = float(reward_model_info.get('env_reward_weight', 0.25))
+    env_reward_scale = float(reward_model_info.get('env_reward_scale', 1.0))
+    # Clip summed env reward; if None, no clipping
+    env_reward_clip_val = reward_model_info.get('env_reward_clip', 5.0)
+    env_reward_clip = float(env_reward_clip_val) if env_reward_clip_val is not None else None
+
+    format_reward_weight = float(reward_model_info.get('format_reward_weight', 0.25))
+    format_max_r = float(reward_model_info.get('format_max_r', 1.0))
+    format_min_r = float(reward_model_info.get('format_min_r', -0.5))  # Allow penalty for bad format
+    format_check_all_tags = bool(reward_model_info.get('format_check_all_tags', True))
+
+    length_reward_weight = float(reward_model_info.get('length_reward_weight', 0.15))
+    length_max_r = float(reward_model_info.get('length_max_r', 0.5))  # Max reward for good length might be less than 1
+    length_min_r = float(reward_model_info.get('length_min_r', -0.25))
+    length_target_words = int(reward_model_info.get('length_target_words', 50))
+    length_penalty_if_missing = bool(reward_model_info.get('length_penalty_if_missing', True))
+    length_too_short_penalty_factor = float(reward_model_info.get('length_too_short_penalty_factor', 0.5))
+    length_too_long_penalty_factor = float(reward_model_info.get('length_too_long_penalty_factor', 0.5))
+    length_tolerance_factor = float(reward_model_info.get('length_tolerance_factor', 0.3))
+
+    gt_sim_reward_weight = float(reward_model_info.get('gt_sim_reward_weight', 0.35))
+    gt_sim_max_r = float(reward_model_info.get('gt_sim_max_r', 1.0))
+    gt_sim_min_r = float(reward_model_info.get('gt_sim_min_r', 0.0))
+    ground_truth_actions = reward_model_info.get('ground_truth_actions', [])
+
+    # --- Component 1: Environment Reward Summation ---
+    env_reward_score_component = _compute_env_reward_sum(trajectory, env_reward_scale, env_reward_clip)
+
+    # --- Consolidate Agent Text for Format/Length ---
+    agent_generations_text = ""
+    if isinstance(trajectory, list):
+        agent_generations_text = "\n".join([turn['value'] for turn in trajectory if turn.get('from') == 'gpt' and isinstance(turn.get('value'), str)])
+    else:
+        print(f"Warning: Unexpected trajectory format: {type(trajectory)}. Format/length/GT rewards might be inaccurate.")
+
+    # --- Component 2: Format Reward ---
+    format_score_component = _compute_format_reward(
+        agent_generations_text, format_max_r, format_min_r, format_check_all_tags
+    )
+
+    # --- Component 3: Length Reward (e.g., for combined <think> content) ---
+    all_think_content = ""
+    if agent_generations_text:
+        think_pattern = r"<think>(.*?)</think>"
+        for match in re.finditer(think_pattern, agent_generations_text, re.DOTALL):
+            all_think_content += match.group(1).strip() + " "
+        all_think_content = all_think_content.strip()
+
+    length_score_component = _compute_length_reward(
+        all_think_content, length_max_r, length_min_r, length_target_words,
+        length_penalty_if_missing, length_too_short_penalty_factor, length_too_long_penalty_factor,
+        length_tolerance_factor
+    )
+
+    # --- Component 4: Ground Truth Trajectory Similarity ---
+    generated_actions = []
+    if isinstance(trajectory, list):
+        generated_actions = _extract_actions_from_trajectory(trajectory)
+
+    gt_sim_score_component = _compute_gt_traj_similarity_reward(
+        generated_actions, ground_truth_actions, gt_sim_max_r, gt_sim_min_r
+    )
+
+    # --- Total Score ---
+    total_score = (
+        env_reward_weight * env_reward_score_component +
+        format_reward_weight * format_score_component +
+        length_reward_weight * length_score_component +
+        gt_sim_reward_weight * gt_sim_score_component
+    )
+
+    # Overall clipping/scaling if desired, e.g., to a standard range like [-1, 1] or [0, 1].
+    # For example, if weights sum to 1, this might not be strictly needed unless components can be large.
+    # total_score = max(-1.0, min(1.0, total_score))  # Example clip

-    except Exception as e:
-        print(f"Error computing AgentGym score from trajectory for env='{env_name}', style='{style}': {e}")
-        # Optionally log traceback: import traceback; print(traceback.format_exc())
-        score = 0.0  # Return 0 on error
+    # For debugging, print individual scores:
+    # print(f"[compute_score] Env '{env_name}': \
+    #     EnvR_raw={env_reward_score_component:.2f} (w={env_reward_weight:.2f}), \
+    #     FmtR_raw={format_score_component:.2f} (w={format_reward_weight:.2f}), \
+    #     LenR_raw={length_score_component:.2f} (w={length_reward_weight:.2f}), \
+    #     GtSimR_raw={gt_sim_score_component:.2f} (w={gt_sim_reward_weight:.2f}) --- TOTAL_raw={total_score:.2f}")

-    return score
+    return total_score
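
A minimal usage sketch of the new compute_score, assuming the second changed file is the verl.utils.reward_score.agentgym module referenced by the import removed from ray_trainer.py above (the file path is not shown in this extract); the environment name, trajectory, and weight values below are illustrative only:

# Hypothetical call site; inputs are illustrative, not taken from the repo's tests or configs.
from verl.utils.reward_score.agentgym import compute_score

toy_trajectory = [
    {"from": "gpt",
     "value": ("<think>search first</think><memory>nothing yet</memory>"
               "<plan>search, then buy</plan><think>ready to act</think>"
               "<act>search[red shoes]</act>"),
     "reward": 0.5},  # per-step reward as returned by the environment
    {"from": "env", "value": "Observation: 10 results found."},
]

score = compute_score(
    "webshop",  # env_name
    trajectory=toy_trajectory,
    reward_model_info={
        "ground_truth_actions": ["search[red shoes]"],
        # Component weights; these match the defaults above and sum to 1.0 (0.25 + 0.25 + 0.15 + 0.35).
        "env_reward_weight": 0.25,
        "format_reward_weight": 0.25,
        "length_reward_weight": 0.15,
        "gt_sim_reward_weight": 0.35,
    },
)
print(f"composite score: {score:.3f}")

With these inputs the format and ground-truth components reach their maxima, the environment component contributes the clipped per-step reward sum, and the length component is reduced because the combined <think> content is far shorter than the 50-word default target.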

0 commit comments
