Skip to content

Commit af34b3d

Browse files
committed
Reward Redesign & openmanus redesign
1 parent 78e62ac commit af34b3d

File tree

6 files changed

+1034
-436
lines changed

6 files changed

+1034
-436
lines changed
Lines changed: 379 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,379 @@
1+
import torch
2+
import re
3+
from collections import defaultdict
4+
import os
5+
from typing import List, Dict, Any, Tuple
6+
from dataclasses import dataclass
7+
from .tensor_helper import TensorHelper, TensorConfig
8+
from ragen.utils import set_seed
9+
from ragen.utils.plot import (
10+
save_trajectory_to_output,
11+
parse_llm_output
12+
)
13+
from verl import DataProto
14+
from verl.utils.tracking import Tracking
15+
import shutil
16+
17+
@dataclass
class GenerationConfig:
    """Settings consumed by ``LLMGenerationManager`` during multi-turn rollouts."""

    # Upper bound on the number of generate -> act -> observe turns.
    max_turns: int
    # Number of trailing tokens of the initial input kept as the prompt side.
    max_start_length: int
    # Cap applied to the rolling context and to the accumulated responses.
    max_prompt_length: int
    # Not read by the generation manager in this file; presumably consumed
    # by the rollout worker -- TODO confirm.
    max_response_length: int
    # Tokenized observations are truncated to this many tokens.
    max_obs_length: int
    # Accessed with attribute syntax (``logging.log_images``,
    # ``logging.log_n_image_per_batch``, ``logging.log_image_step_size``),
    # so a plain ``dict`` does not satisfy it; any attribute-style config
    # object is accepted.
    logging: Any
    # Generation batches are padded to a multiple of this value.
    num_gpus: int
    # When True, each response is rewritten to contain only the parsed action.
    no_think_rl: bool = False
    # When True, observations are wrapped in the state markers below and any
    # marker-delimited span is stripped from model responses.
    state_masking: bool = False
    start_state_marker: str = "<start-state>"
    end_state_marker: str = "<end-state>"
30+
31+
class LLMGenerationManager:
    """Drive multi-turn LLM rollouts against a batch of environments.

    Each turn generates responses for the still-active environments, cleans
    them up, executes them through ``env_class`` and appends the resulting
    observations to the rolling context used for the next turn.
    """

    def __init__(
        self,
        tokenizer,
        actor_rollout_wg,
        env_class,
        config: "GenerationConfig",
        logger: "Tracking",
        is_validation: bool = False,
    ):
        """Store collaborators and build the tensor helper from ``config``.

        Args:
            tokenizer: Tokenizer used to decode responses and encode
                observations; must expose ``pad_token_id``.
            actor_rollout_wg: Object exposing ``generate_sequences``.
            env_class: Class providing ``execute_predictions`` and
                ``postprocess_predictions``.
            config: Generation settings.
            logger: Tracking object whose ``logger`` mapping may contain
                a ``'wandb'`` entry.
            is_validation: When True, trajectories are saved on every call.
        """
        self.tokenizer = tokenizer
        self.actor_rollout_wg = actor_rollout_wg
        self.env_class = env_class
        self.config = config
        self.logger = logger
        self.is_validation = is_validation

        self.tensor_fn = TensorHelper(TensorConfig(
            pad_token_id=tokenizer.pad_token_id,
            max_prompt_length=config.max_prompt_length,
            max_obs_length=config.max_obs_length,
            max_start_length=config.max_start_length
        ))

    def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
        """Tokenize a batch of strings without special tokens, padded to the longest."""
        return self.tokenizer(
            responses,
            add_special_tokens=False,
            return_tensors='pt',
            padding="longest"
        )['input_ids']

    @staticmethod
    def _process_answer_tag(responses_str):
        """Keep only the first ``<answer>...</answer>`` pair in each response.

        Later answer pairs are unwrapped (tags removed, inner text kept);
        all other text is preserved.

        Args:
            responses_str (List[str]): Response strings that may contain
                answer tags.

        Returns:
            List[str]: Responses with at most one answer tag pair each.
        """
        def process_single_response(resp):
            # Cheap substring test first: most responses have no complete pair.
            if '<answer>' not in resp or '</answer>' not in resp:
                return resp

            match = re.search(r'<answer>.*?</answer>', resp, re.DOTALL)
            if not match:
                # Both tags exist but the closing tag precedes the opening one.
                return resp

            # Strip the tags from every answer pair after the first one.
            cleaned_rest = re.sub(
                r'<answer>(.*?)</answer>', r'\1', resp[match.end():], flags=re.DOTALL
            )
            return resp[:match.start()] + match.group(0) + cleaned_rest

        return [process_single_response(resp) for resp in responses_str]

    def _postprocess_responses(
        self, responses: torch.Tensor, envs: List[Any]
    ) -> Tuple[torch.Tensor, List[str]]:
        """Clean generated responses and re-tokenize them.

        Removes extra answer tags, strips state-marker spans the model tried
        to emit itself (when ``state_masking`` is on) and, when
        ``no_think_rl`` is on, reduces each response to its parsed action.

        Args:
            responses: Generated token ids, one row per response.
            envs: Environments; only used in the ``no_think_rl`` branch.

        Returns:
            ``(response_ids, response_strings)`` after clean-up.
        """
        responses_str = self.tokenizer.batch_decode(
            responses,
            skip_special_tokens=True
        )
        responses_str = self._process_answer_tag(responses_str)

        if self.config.state_masking:
            # Markers are literal text, so escape them before embedding in a regex.
            start_marker = re.escape(self.config.start_state_marker)
            end_marker = re.escape(self.config.end_state_marker)
            hack_pattern = re.compile(
                f'{start_marker}[\\s\\S]*?{end_marker}', flags=re.DOTALL
            )

            hacked = [resp for resp in responses_str if hack_pattern.search(resp)]
            if hacked:
                print(f"[WARNING] HACKED RESPONSES: {hacked}")
            # Remove every marker span. The flag used to be passed
            # positionally to re.sub, where it is interpreted as ``count``
            # and capped the number of removals at 16.
            responses_str = [hack_pattern.sub('', resp) for resp in responses_str]

        if self.config.no_think_rl:
            # Only keep the action in the response string.
            # NOTE(review): ``envs`` is the full environment list while
            # ``responses_str`` only holds rows for active environments (see
            # run_llm_loop) -- confirm the index alignment below is intended.
            actions, _ = self.env_class.postprocess_predictions(envs, responses_str)
            responses_str = [
                f"<answer>{envs[idx].ACTION_LOOKUP[action]}</answer>"
                for idx, action in enumerate(actions)
            ]
            print("RESPONSES:", responses_str)

        responses = self._batch_tokenize(responses_str)
        return responses, responses_str

    def _process_next_obs(self, next_obs: List[str]) -> torch.Tensor:
        """Tokenize environment observations, optionally wrapped in state markers.

        Args:
            next_obs: One observation string per environment.

        Returns:
            Token ids truncated to at most ``config.max_obs_length`` columns.
        """
        if self.config.state_masking:
            start_marker = self.config.start_state_marker
            end_marker = self.config.end_state_marker

            # Build "inner" variants by inserting 'inner_' after the first
            # character (assumes markers start with '<').
            inner_start = f"<inner_{start_marker[1:]}"
            inner_end = f"<inner_{end_marker[1:]}"

            # Literal replacement: markers already present inside an
            # observation must not be confused with the wrapping markers.
            # str.replace avoids regex escape processing of the replacement.
            next_obs = [
                obs.replace(start_marker, inner_start).replace(end_marker, inner_end)
                for obs in next_obs
            ]

            # Wrap each observation with the state markers.
            next_obs = [f"{start_marker}{obs}{end_marker}" for obs in next_obs]

        next_obs_ids = self.tokenizer(
            next_obs,
            padding='longest',
            return_tensors='pt'
        )['input_ids']

        if next_obs_ids.shape[1] > self.config.max_obs_length:
            print("[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG")
            next_obs_ids = next_obs_ids[:, :self.config.max_obs_length]

        return next_obs_ids

    def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
                              next_obs_ids: torch.Tensor) -> "DataProto":
        """Append the latest responses and observations to the rolling context.

        Returns:
            A new ``DataProto`` with ``input_ids``, ``position_ids`` and
            ``attention_mask``, keeping at most ``config.max_prompt_length``
            trailing columns.
        """
        new_input_ids = self.tensor_fn.concatenate_with_padding([
            rollings.batch['input_ids'],
            cur_responses,
            next_obs_ids
        ])

        new_attention_mask = self.tensor_fn.create_attention_mask(new_input_ids)
        new_position_ids = self.tensor_fn.create_position_ids(new_attention_mask)

        # Keep the trailing window: the longest row's token count, capped by
        # the prompt-length limit.
        effective_len = new_attention_mask.sum(dim=1).max().item()
        max_len = min(self.config.max_prompt_length, effective_len)

        return DataProto.from_dict({
            'input_ids': new_input_ids[:, -max_len:],
            'position_ids': new_position_ids[:, -max_len:],
            'attention_mask': new_attention_mask[:, -max_len:]
        })

    def _update_right_side(self, right_side: Dict,
                           cur_responses: torch.Tensor,
                           next_obs_ids: torch.Tensor) -> Dict:
        """Append responses and observations to the accumulated response side.

        Returns:
            ``{'responses': tensor}`` keeping at most
            ``config.max_prompt_length`` leading columns.
        """
        responses = self.tensor_fn.concatenate_with_padding([
            right_side['responses'],
            cur_responses,
            next_obs_ids
        ], pad_to_left=False)

        effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max().item()
        max_len = min(self.config.max_prompt_length, effective_len)

        return {'responses': responses[:, :max_len]}

    def _generate_with_gpu_padding(self, active_batch: "DataProto") -> "DataProto":
        """Generate sequences, padding the batch to a multiple of ``num_gpus``.

        With ``num_gpus <= 1`` or an already divisible batch, generation is
        called directly. Otherwise the batch is padded with copies of its
        first row, generated, and the padded rows are dropped from the
        output batch and from any tensor entries in ``meta_info``.
        """
        num_gpus = self.config.num_gpus
        if num_gpus <= 1:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        batch_size = active_batch.batch['input_ids'].shape[0]
        remainder = batch_size % num_gpus
        if remainder == 0:
            return self.actor_rollout_wg.generate_sequences(active_batch)

        padding_size = num_gpus - remainder
        padded_batch = {}
        for k, v in active_batch.batch.items():
            # Repeat the first row along dim 0 only; other dims are kept as-is.
            pad_sequence = v[0:1].repeat(padding_size, *[1] * (len(v.shape) - 1))
            padded_batch[k] = torch.cat([v, pad_sequence], dim=0)

        padded_active_batch = DataProto.from_dict(padded_batch)
        padded_output = self.actor_rollout_wg.generate_sequences(padded_active_batch)

        # Drop the rows that were added only to satisfy divisibility.
        trimmed_batch = {k: v[:-padding_size] for k, v in padded_output.batch.items()}

        if hasattr(padded_output, 'meta_info') and padded_output.meta_info:
            padded_output.meta_info = {
                k: (v[:-padding_size] if isinstance(v, torch.Tensor) else v)
                for k, v in padded_output.meta_info.items()
            }

        padded_output.batch = trimmed_batch
        return padded_output

    def run_llm_loop(self, gen_batch, envs: List[Any],
                     initial_input_ids: torch.Tensor,
                     output_dir: str,
                     global_steps: int) -> "DataProto":
        """Run the multi-turn generation loop and assemble the final batch.

        Args:
            gen_batch: Initial generation batch with ``input_ids``,
                ``attention_mask`` and ``position_ids`` in ``.batch``.
            envs: One environment per batch row.
            initial_input_ids: Prompt token ids; the trailing
                ``max_start_length`` columns become the prompt side.
            output_dir: Directory for saved trajectory visualizations.
            global_steps: Current training step, used to decide when to save.

        Returns:
            The composed output of ``_compose_final_output``.
        """
        trajectory = self._setup_visualization()

        original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
        # Zero-column tensor with the right batch size / dtype to accumulate into.
        original_right_side = {'responses': initial_input_ids[:, []]}

        active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
        active_num_list = [active_mask.sum().item()]
        rollings = gen_batch
        # Bound up front so the final composition works even when the loop
        # body never runs (max_turns == 0 or no active rows).
        meta_info = {}

        for _ in range(self.config.max_turns):
            if not active_mask.sum():
                break
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # Only still-active rows are sent to generation.
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            responses_ids, responses_str = self._postprocess_responses(
                gen_output.batch['responses'], envs=envs
            )
            # Presumably re-expands active-only rows to the full batch using
            # active_mask -- TODO confirm against TensorHelper.
            responses_ids, responses_str = self.tensor_fn._example_level_pad(
                responses_ids, responses_str, active_mask
            )

            self._update_trajectory(trajectory, envs, responses_str, active_mask)

            next_obs, dones = self.env_class.execute_predictions(
                envs, responses_str, responses_ids, self.tokenizer
            )

            active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_num_list.append(active_mask.sum().item())
            next_obs_ids = self._process_next_obs(next_obs)

            rollings = self._update_rolling_state(
                rollings,
                responses_ids,
                next_obs_ids
            )
            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
                next_obs_ids
            )
        print("ACTIVE_TRAJ_NUM:", active_num_list)

        self._save_trajectory(trajectory, output_dir, global_steps)
        return self._compose_final_output(original_left_side, original_right_side, meta_info)

    def _setup_visualization(self) -> List[Dict]:
        """Create per-sample trajectory buffers, or return None when image logging is off."""
        if not self.config.logging.log_images:
            return None
        return [defaultdict(list) for _ in range(self.config.logging.log_n_image_per_batch)]

    def _update_trajectory(self, trajectory: List[Dict],
                           envs: List[Any], responses: List[str], active_mask: torch.Tensor):
        """Record rendered state and response for the first N active environments.

        Does nothing when ``trajectory`` is None or empty.
        """
        if not trajectory:
            return
        n_visualize = self.config.logging.log_n_image_per_batch
        for idx, (env, active) in enumerate(zip(envs[:n_visualize], active_mask[:n_visualize])):
            if active:
                trajectory[idx]['state'].append(env.render('rgb_array'))

        for idx, (response, _env, active) in enumerate(zip(responses[:n_visualize],
                                                           envs[:n_visualize],
                                                           active_mask[:n_visualize])):
            if active:
                parsed = parse_llm_output(response, strategy="raw")

                trajectory[idx]['answer'].append(response)
                trajectory[idx]['parsed_response'].append(parsed)

    def _save_trajectory(self, trajectory: List[Dict],
                         output_dir: str, global_steps: int):
        """Write the trajectory to ``output_dir`` and upload files to wandb if configured.

        Saving happens when ``global_steps`` is a multiple of
        ``logging.log_image_step_size`` or when running validation.
        """
        if not trajectory:
            return

        save_step_size = self.config.logging.log_image_step_size
        if not global_steps % save_step_size or self.is_validation:
            os.makedirs(output_dir, exist_ok=True)
            filenames = save_trajectory_to_output(trajectory, save_dir=output_dir)
            if 'wandb' in self.logger.logger:
                for filename in filenames:
                    self.logger.logger['wandb'].save(filename)

    def _compose_final_output(self, left_side: Dict,
                              right_side: Dict,
                              meta_info: Dict) -> "DataProto":
        """Join prompt and response sides into one batch with masks and positions.

        Returns:
            A ``DataProto`` holding ``responses``, ``prompts``, ``input_ids``,
            ``attention_mask`` and ``position_ids``, with ``meta_info`` merged in.
        """
        final_output = right_side.copy()
        final_output['prompts'] = left_side['input_ids']

        final_output['input_ids'] = torch.cat([
            left_side['input_ids'],
            right_side['responses']
        ], dim=1)

        # Masks are built per side and then concatenated in the same order.
        final_output['attention_mask'] = torch.cat([
            self.tensor_fn.create_attention_mask(left_side['input_ids']),
            self.tensor_fn.create_attention_mask(final_output['responses'])
        ], dim=1)

        final_output['position_ids'] = self.tensor_fn.create_position_ids(
            final_output['attention_mask']
        )

        final_output = DataProto.from_dict(final_output)
        final_output.meta_info.update(meta_info)

        return final_output

0 commit comments

Comments
 (0)