@@ -174,95 +174,6 @@ def _batch_tokenize(self, responses: List[str]) -> torch.Tensor:
             padding="longest"
         )['input_ids']
 
-    def _postprocess_responses(self, responses: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
-        """
-        Process responses to stop at tool call or final response.
-        Handles tags like <action> and </action> or <response> and </response>.
-
-        Args:
-            responses: Tensor containing response token IDs
-
-        Returns:
-            Tuple of (processed response tensor, processed response strings)
-        """
-        # Decode responses to strings
-        responses_str = self.tokenizer.batch_decode(
-            responses,
-            skip_special_tokens=True
-        )
-
-        # Process each response to extract action/response content
-        processed_responses = []
-        for resp in responses_str:
-            if '</action>' in resp:
-                # Stop at end of action
-                processed = resp.split('</action>')[0] + '</action>'
-            elif '</response>' in resp:
-                # Stop at end of response
-                processed = resp.split('</response>')[0] + '</response>'
-            else:
-                # No recognized end tag, keep as is
-                processed = resp
-            processed_responses.append(processed)
-
-        # Re-tokenize processed responses
-        responses = self._batch_tokenize(processed_responses)
-
-        return responses, processed_responses
-
-    def _process_next_obs(self, next_obs: List[str]) -> torch.Tensor:
-        """
-        Process next observations from environment.
-        Tokenizes observations and handles maximum length constraints.
-
-        Args:
-            next_obs: List of observation strings from the environment
-
-        Returns:
-            Tensor of tokenized observations
-        """
-        # Tokenize observations with consistent padding
-        next_obs_ids = self.tokenizer(
-            next_obs,
-            padding='longest',
-            return_tensors='pt',
-            add_special_tokens=False,  # Prevents adding special tokens
-        )['input_ids']
-
-        # Truncate if observations are too long
-        if next_obs_ids.shape[1] > self.config.max_obs_length:
-            print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}")
-            # Truncate to max_obs_length
-            next_obs_ids = next_obs_ids[:, :self.config.max_obs_length]
-
-        return next_obs_ids
-
-    def _update_rolling_state(self, rollings: DataProto, cur_responses: torch.Tensor,
-                              next_obs_ids: torch.Tensor) -> DataProto:
-        """Update rolling state with new responses and observations."""
-        # Concatenate and handle padding
-        new_input_ids = self.tensor_fn.concatenate_with_padding([
-            rollings.batch['input_ids'],
-            cur_responses,
-            next_obs_ids
-        ])
-
-        # Create attention mask and position ids
-        new_attention_mask = self.tensor_fn.create_attention_mask(new_input_ids)
-        new_position_ids = self.tensor_fn.create_position_ids(new_attention_mask)
-
-        # Cut to appropriate length
-        effective_len = new_attention_mask.sum(dim=1).max()
-        max_len = min(self.config.max_prompt_length, effective_len)
-
-        new_rollings = DataProto.from_dict({
-            'input_ids': new_input_ids[:, -max_len:],
-            'position_ids': new_position_ids[:, -max_len:],
-            'attention_mask': new_attention_mask[:, -max_len:]
-        })
-        new_rollings.meta_info.update(rollings.meta_info)
-
-        return new_rollings
 
     def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int, client: Any) -> Dict[str, Any]:
         """
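
For reference, a minimal standalone sketch of the stop-tag truncation that the removed `_postprocess_responses` implemented: decode each generation, cut it at the first recognized closing tag, then re-tokenize so the returned tensor matches the truncated strings. The `truncate_at_stop_tag`/`postprocess` names and the `gpt2` tokenizer are illustrative assumptions, not part of this repository.

```python
from typing import List, Tuple

import torch
from transformers import AutoTokenizer

# Closing tags recognized by the rollout loop; checked in this fixed order,
# so '</action>' takes precedence when both appear, mirroring the removed
# method's if/elif logic.
STOP_TAGS = ('</action>', '</response>')


def truncate_at_stop_tag(text: str) -> str:
    """Cut a decoded generation at the first recognized closing tag, keeping the tag."""
    for tag in STOP_TAGS:
        if tag in text:
            return text.split(tag)[0] + tag
    return text  # no recognized end tag: keep the text unchanged


def postprocess(tokenizer, responses: torch.Tensor) -> Tuple[torch.Tensor, List[str]]:
    decoded: List[str] = tokenizer.batch_decode(responses, skip_special_tokens=True)
    processed = [truncate_at_stop_tag(t) for t in decoded]
    # Re-tokenize so the returned tensor matches the truncated strings.
    ids = tokenizer(processed, padding='longest', return_tensors='pt',
                    add_special_tokens=False)['input_ids']
    return ids, processed


if __name__ == '__main__':
    tok = AutoTokenizer.from_pretrained('gpt2')  # stand-in tokenizer for the demo
    tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default
    sample = tok(['<action>search(query)</action> trailing text'],
                 return_tensors='pt')['input_ids']
    _, texts = postprocess(tok, sample)
    print(texts)  # ['<action>search(query)</action>']
```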