1+ import torch
2+ import re
3+ from collections import defaultdict
4+ import os
5+ from typing import List , Dict , Any , Tuple
6+ from dataclasses import dataclass
7+ from .tensor_helper import TensorHelper , TensorConfig
8+ from ragen .utils import set_seed
9+ from ragen .utils .plot import (
10+ save_trajectory_to_output ,
11+ parse_llm_output
12+ )
13+ from verl import DataProto
14+ from verl .utils .tracking import Tracking
15+ import shutil
16+
17+ @dataclass
18+ class GenerationConfig :
19+ max_turns : int
20+ max_start_length : int
21+ max_prompt_length : int
22+ max_response_length : int
23+ max_obs_length : int
24+ logging : dict
25+ num_gpus : int
26+ no_think_rl : bool = False
27+ state_masking : bool = False
28+ start_state_marker : str = "<start-state>"
29+ end_state_marker : str = "<end-state>"
30+
31+ class LLMGenerationManager :
32+ def __init__ (
33+ self ,
34+ tokenizer ,
35+ actor_rollout_wg ,
36+ env_class ,
37+ config : GenerationConfig ,
38+ logger : Tracking ,
39+ is_validation : bool = False ,
40+ ):
41+ self .tokenizer = tokenizer
42+ self .actor_rollout_wg = actor_rollout_wg
43+ self .env_class = env_class
44+ self .config = config
45+ self .logger = logger
46+ self .is_validation = is_validation
47+
48+ self .tensor_fn = TensorHelper (TensorConfig (
49+ pad_token_id = tokenizer .pad_token_id ,
50+ max_prompt_length = config .max_prompt_length ,
51+ max_obs_length = config .max_obs_length ,
52+ max_start_length = config .max_start_length
53+ ))
54+
55+ def _batch_tokenize (self , responses : List [str ]) -> torch .Tensor :
56+ """Tokenize a batch of responses."""
57+ return self .tokenizer (
58+ responses ,
59+ add_special_tokens = False ,
60+ return_tensors = 'pt' ,
61+ padding = "longest"
62+ )['input_ids' ]
63+
64+ @staticmethod
65+ def _process_answer_tag (responses_str ):
66+ """
67+ Process a list of response strings to keep only the first <answer></answer> tag pair
68+ while preserving the rest of the string content.
69+
70+ Args:
71+ responses_str (List[str]): List of response strings potentially containing answer tags
72+
73+ Returns:
74+ List[str]: Processed responses with only first answer tag pair preserved
75+ """
76+ def process_single_response (resp ):
77+ # If no answer tags present, return original string
78+ if '<answer>' not in resp or '</answer>' not in resp :
79+ return resp
80+
81+ # Find the first complete <answer> tag pair
82+ pattern = r'<answer>.*?</answer>'
83+ match = re .search (pattern , resp , re .DOTALL )
84+
85+ if not match :
86+ return resp
87+
88+ # Get the matched answer tag content
89+ answer_content = match .group (0 )
90+
91+ # Replace all subsequent answer tag pairs with their content
92+ rest_of_string = resp [match .end ():]
93+ cleaned_rest = re .sub (r'<answer>(.*?)</answer>' , r'\1' , rest_of_string , flags = re .DOTALL )
94+
95+ return resp [:match .start ()] + answer_content + cleaned_rest
96+
97+ # Process each response string
98+ return [process_single_response (resp ) for resp in responses_str ]
99+
100+ def _postprocess_responses (self , responses : torch .Tensor , envs : List [Any ]) -> torch .Tensor :
101+ """Process responses to remove 1. multiple answers or 2. reward hacking attempts."""
102+ responses_str = self .tokenizer .batch_decode (
103+ responses ,
104+ skip_special_tokens = True
105+ )
106+
107+ # responses_str = [resp.split('</answer>')[0] + '</answer>'
108+ # if '</answer>' in resp else resp
109+ # for resp in responses_str]
110+ responses_str = self ._process_answer_tag (responses_str )
111+
112+ if self .config .state_masking :
113+ # Escape special characters in markers for regex
114+ start_marker = re .escape (self .config .start_state_marker )
115+ end_marker = re .escape (self .config .end_state_marker )
116+ hack_pattern = f'{ start_marker } [\\ s\\ S]*?{ end_marker } '
117+
118+ hacked = [resp for resp in responses_str if re .search (hack_pattern , resp , re .DOTALL )]
119+ if hacked :
120+ print (f"[WARNING] HACKED RESPONSES: { hacked } " )
121+ responses_str = [re .sub (hack_pattern , '' , resp , re .DOTALL ) for resp in responses_str ]
122+
123+ if self .config .no_think_rl :
124+ # if no_think_rl is enabled, only keep action in the str
125+ actions , _ = self .env_class .postprocess_predictions (envs , responses_str )
126+ responses_str = [f"<answer>{ envs [idx ].ACTION_LOOKUP [action ]} </answer>" for idx , action in enumerate (actions )]
127+ print ("RESPONSES:" , responses_str )
128+ responses = self ._batch_tokenize (responses_str )
129+ return responses , responses_str
130+
131+
132+
133+ def _process_next_obs (self , next_obs : List [str ]) -> torch .Tensor :
134+ """Process next observations from environment."""
135+ if self .config .state_masking :
136+ start_marker = self .config .start_state_marker
137+ end_marker = self .config .end_state_marker
138+
139+ # Create inner versions by adding 'inner_' prefix
140+ inner_start = f"<inner_{ start_marker [1 :]} "
141+ inner_end = f"<inner_{ end_marker [1 :]} "
142+
143+ # Replace any existing markers with inner versions
144+ next_obs = [re .sub (re .escape (start_marker ), inner_start , obs ) for obs in next_obs ]
145+ next_obs = [re .sub (re .escape (end_marker ), inner_end , obs ) for obs in next_obs ]
146+
147+ # Wrap with state markers
148+ next_obs = [f"{ start_marker } { obs } { end_marker } " for obs in next_obs ]
149+
150+ next_obs_ids = self .tokenizer (
151+ next_obs ,
152+ padding = 'longest' ,
153+ return_tensors = 'pt'
154+ )['input_ids' ]
155+
156+ if next_obs_ids .shape [1 ] > self .config .max_obs_length :
157+ print ("[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG" )
158+ next_obs_ids = next_obs_ids [:, :self .config .max_obs_length ]
159+
160+ return next_obs_ids
161+
162+ def _update_rolling_state (self , rollings , cur_responses : torch .Tensor ,
163+ next_obs_ids : torch .Tensor ) -> Dict :
164+ """Update rolling state with new responses and observations."""
165+ # Concatenate and handle padding
166+ new_input_ids = self .tensor_fn .concatenate_with_padding ([
167+ rollings .batch ['input_ids' ],
168+ cur_responses ,
169+ next_obs_ids
170+ ])
171+
172+ # Create attention mask and position ids
173+ new_attention_mask = self .tensor_fn .create_attention_mask (new_input_ids )
174+ new_position_ids = self .tensor_fn .create_position_ids (new_attention_mask )
175+
176+ # Cut to appropriate length
177+ effective_len = new_attention_mask .sum (dim = 1 ).max ()
178+ max_len = min (self .config .max_prompt_length , effective_len )
179+
180+ return DataProto .from_dict ({
181+ 'input_ids' : new_input_ids [:, - max_len :],
182+ 'position_ids' : new_position_ids [:, - max_len :],
183+ 'attention_mask' : new_attention_mask [:, - max_len :]
184+ })
185+
186+ def _update_right_side (self , right_side : Dict ,
187+ cur_responses : torch .Tensor ,
188+ next_obs_ids : torch .Tensor ) -> Dict :
189+ """Update right side state."""
190+ responses = self .tensor_fn .concatenate_with_padding ([
191+ right_side ['responses' ],
192+ cur_responses ,
193+ next_obs_ids
194+ ], pad_to_left = False )
195+
196+ effective_len = self .tensor_fn .create_attention_mask (responses ).sum (dim = 1 ).max ()
197+ max_len = min (self .config .max_prompt_length , effective_len )
198+
199+ return {'responses' : responses [:, :max_len ]}
200+
201+
202+ def _generate_with_gpu_padding (self , active_batch : DataProto ) -> DataProto :
203+ """
204+ Wrapper for generation that handles multi-GPU padding requirements.
205+ if num_gpus <= 1, return self.actor_rollout_wg.generate_sequences(active_batch)
206+ if active_batch size is not divisible by num_gpus, pad with first sequence
207+ then remove padding from output
208+ """
209+ num_gpus = self .config .num_gpus
210+ if num_gpus <= 1 :
211+ return self .actor_rollout_wg .generate_sequences (active_batch )
212+
213+ batch_size = active_batch .batch ['input_ids' ].shape [0 ]
214+ remainder = batch_size % num_gpus
215+
216+ if remainder == 0 :
217+ return self .actor_rollout_wg .generate_sequences (active_batch )
218+
219+ # Add padding sequences
220+ padding_size = num_gpus - remainder
221+ padded_batch = {}
222+
223+ for k , v in active_batch .batch .items ():
224+ # Use first sequence as padding template
225+ pad_sequence = v [0 :1 ].repeat (padding_size , * [1 ] * (len (v .shape ) - 1 ))
226+ padded_batch [k ] = torch .cat ([v , pad_sequence ], dim = 0 )
227+
228+ padded_active_batch = DataProto .from_dict (padded_batch )
229+
230+ # Generate with padded batch
231+ padded_output = self .actor_rollout_wg .generate_sequences (padded_active_batch )
232+
233+ # Remove padding from output
234+ trimmed_batch = {k : v [:- padding_size ] for k , v in padded_output .batch .items ()}
235+
236+ # Handle meta_info if present
237+ if hasattr (padded_output , 'meta_info' ) and padded_output .meta_info :
238+ trimmed_meta = {}
239+ for k , v in padded_output .meta_info .items ():
240+ if isinstance (v , torch .Tensor ):
241+ trimmed_meta [k ] = v [:- padding_size ]
242+ else :
243+ trimmed_meta [k ] = v
244+ padded_output .meta_info = trimmed_meta
245+
246+ padded_output .batch = trimmed_batch
247+ return padded_output
248+
249+ def run_llm_loop (self , gen_batch , envs : List [Any ],
250+ initial_input_ids : torch .Tensor ,
251+ output_dir : str ,
252+ global_steps : int ) -> Tuple [Dict , Dict ]:
253+ """Run main LLM generation loop."""
254+ # Setup visualization and Initialize states
255+ trajectory = self ._setup_visualization ()
256+
257+ original_left_side = {'input_ids' : initial_input_ids [:, - self .config .max_start_length :]}
258+ original_right_side = {'responses' : initial_input_ids [:, []]}
259+
260+ active_mask = torch .ones (gen_batch .batch ['input_ids' ].shape [0 ], dtype = torch .bool )
261+ active_num_list = [active_mask .sum ().item ()]
262+ rollings = gen_batch
263+
264+
265+ # Main generation loop
266+ for step in range (self .config .max_turns ):
267+ if not active_mask .sum ():
268+ break
269+ rollings .batch = self .tensor_fn .cut_to_effective_len (
270+ rollings .batch ,
271+ keys = ['input_ids' , 'attention_mask' , 'position_ids' ]
272+ )
273+
274+ # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
275+ rollings_active = DataProto .from_dict ({
276+ k : v [active_mask ] for k , v in rollings .batch .items ()
277+ })
278+ gen_output = self ._generate_with_gpu_padding (rollings_active )
279+
280+ meta_info = gen_output .meta_info
281+ responses_ids , responses_str = self ._postprocess_responses (gen_output .batch ['responses' ],envs = envs )
282+ responses_ids , responses_str = self .tensor_fn ._example_level_pad (responses_ids , responses_str , active_mask )
283+
284+ # Update visualization
285+ self ._update_trajectory (trajectory , envs , responses_str , active_mask )
286+
287+ # Execute in environment and process observations
288+ next_obs , dones = self .env_class .execute_predictions (
289+ envs , responses_str , responses_ids , self .tokenizer
290+ )
291+
292+ active_mask = torch .tensor ([not done for done in dones ], dtype = torch .bool )
293+ active_num_list .append (active_mask .sum ().item ())
294+ next_obs_ids = self ._process_next_obs (next_obs )
295+
296+ # Update states
297+ rollings = self ._update_rolling_state (
298+ rollings ,
299+ responses_ids ,
300+ next_obs_ids
301+ )
302+ original_right_side = self ._update_right_side (
303+ original_right_side ,
304+ responses_ids ,
305+ next_obs_ids
306+ )
307+ print ("ACTIVE_TRAJ_NUM:" , active_num_list )
308+
309+ # Save trajectory and return final output
310+ self ._save_trajectory (trajectory , output_dir , global_steps )
311+ return self ._compose_final_output (original_left_side , original_right_side , meta_info )
312+
313+ def _setup_visualization (self ) -> List [Dict ]:
314+ """Setup visualization tracking if enabled."""
315+ if not self .config .logging .log_images :
316+ return None
317+ return [defaultdict (list ) for _ in range (self .config .logging .log_n_image_per_batch )]
318+
319+ def _update_trajectory (self , trajectory : List [Dict ],
320+ envs : List [Any ], responses : List [str ], active_mask : torch .Tensor ):
321+ """Update visualization trajectory if enabled."""
322+ if not trajectory :
323+ return
324+ n_visualize = self .config .logging .log_n_image_per_batch
325+ for idx , (env , active ) in enumerate (zip (envs [:n_visualize ], active_mask [:n_visualize ])):
326+ if active :
327+ trajectory [idx ]['state' ].append (env .render ('rgb_array' ))
328+
329+ for idx , (response , env , active ) in enumerate (zip (responses [:n_visualize ],
330+ envs [:n_visualize ],
331+ active_mask [:n_visualize ])):
332+ if active :
333+ parsed = parse_llm_output (response , strategy = "raw" )
334+
335+ trajectory [idx ]['answer' ].append (response )
336+ trajectory [idx ]['parsed_response' ].append (parsed )
337+
338+ def _save_trajectory (self , trajectory : List [Dict ],
339+ output_dir : str , global_steps : int ):
340+ """Save trajectory visualization if enabled."""
341+ if not trajectory :
342+ return
343+
344+ save_step_size = self .config .logging .log_image_step_size
345+ if not global_steps % save_step_size or self .is_validation :
346+ os .makedirs (output_dir , exist_ok = True )
347+ filenames = save_trajectory_to_output (trajectory , save_dir = output_dir )
348+ if 'wandb' in self .logger .logger :
349+ for filename in filenames :
350+ self .logger .logger ['wandb' ].save (filename )
351+
352+
353+ def _compose_final_output (self , left_side : Dict ,
354+ right_side : Dict ,
355+ meta_info : Dict ) -> Tuple [Dict , Dict ]:
356+ """Compose final generation output."""
357+ final_output = right_side .copy ()
358+ final_output ['prompts' ] = left_side ['input_ids' ]
359+
360+ # Combine input IDs
361+ final_output ['input_ids' ] = torch .cat ([
362+ left_side ['input_ids' ],
363+ right_side ['responses' ]
364+ ], dim = 1 )
365+
366+ # Create attention mask and position ids
367+ final_output ['attention_mask' ] = torch .cat ([
368+ self .tensor_fn .create_attention_mask (left_side ['input_ids' ]),
369+ self .tensor_fn .create_attention_mask (final_output ['responses' ])
370+ ], dim = 1 )
371+
372+ final_output ['position_ids' ] = self .tensor_fn .create_position_ids (
373+ final_output ['attention_mask' ]
374+ )
375+
376+ final_output = DataProto .from_dict (final_output )
377+ final_output .meta_info .update (meta_info )
378+
379+ return final_output
0 commit comments