@@ -230,9 +230,20 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
     tool_call_count = 0  # Track actual tool call rounds
 
     for turn in range(TOOL_CONFIGS["max_turns"]):
-        # Simple: just send prompt + response
+        # Check if total length exceeds max context length
+        total_length = len(prompt_tokens_ids) + len(response_token_ids)
+        if args.rollout_max_context_len is not None:
+            max_context_length = args.rollout_max_context_len
+        else:
+            max_context_length = args.context_parallel_size * args.max_tokens_per_gpu
+        if total_length >= max_context_length:
+            sample.status = Sample.Status.TRUNCATED
+            break
+
+        # Use token IDs instead of text
+        current_token_ids = prompt_tokens_ids + response_token_ids
         payload = {
-            "text": prompt + response,
+            "input_ids": current_token_ids,
             "sampling_params": sampling_params,
             "return_logprob": True,  # Request log probabilities for training
@@ -265,15 +276,16 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
             sample.status = Sample.Status.ABORTED
             return sample
 
-        cur_response = output["text"]
-
         if "output_token_logprobs" in output["meta_info"]:
             cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
+            cur_response = state.tokenizer.decode(cur_response_token_ids)
             cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
             if sample.rollout_log_probs is None:
                 sample.rollout_log_probs = []
             sample.rollout_log_probs += cur_log_probs
+
         else:
+            cur_response = output["text"]
             cur_response = postprocess_responses(cur_response)
             cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
 
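For context, here is a standalone sketch of what the new per-turn guard does: compute the context budget from `rollout_max_context_len` (falling back to `context_parallel_size * max_tokens_per_gpu`) and send token IDs rather than re-tokenized text. The helper name `build_generate_payload` and the return-None-on-overflow shape are illustrative, not part of the patch; the `args` fields and payload keys mirror the diff above.

```python
# Illustrative sketch only; build_generate_payload is a hypothetical helper.
# The args fields and payload keys mirror the diff above.
def build_generate_payload(args, prompt_tokens_ids, response_token_ids, sampling_params):
    if args.rollout_max_context_len is not None:
        max_context_length = args.rollout_max_context_len
    else:
        max_context_length = args.context_parallel_size * args.max_tokens_per_gpu

    # Stop once prompt + accumulated response fills the budget; the caller
    # marks the sample TRUNCATED and breaks out of the turn loop.
    if len(prompt_tokens_ids) + len(response_token_ids) >= max_context_length:
        return None

    # Sending token IDs avoids re-tokenization drift across multi-turn rounds.
    return {
        "input_ids": prompt_tokens_ids + response_token_ids,
        "sampling_params": sampling_params,
        "return_logprob": True,
    }
```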
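Similarly, a minimal sketch of the new response handling: when the engine returns `output_token_logprobs` (indexed as `item[0]` = log prob, `item[1]` = token id, per the diff), the response text is decoded from the returned token IDs instead of being taken from `output["text"]`, so the stored token IDs, log probs, and text stay consistent. The names below (`split_response`, `tokenizer`) are illustrative stand-ins for the patch's `state.tokenizer` path.

```python
# Illustrative sketch; split_response is hypothetical and tokenizer
# stands in for state.tokenizer used in the diff above.
def split_response(output, tokenizer, postprocess_responses):
    meta = output["meta_info"]
    if "output_token_logprobs" in meta:
        token_ids = [item[1] for item in meta["output_token_logprobs"]]
        log_probs = [item[0] for item in meta["output_token_logprobs"]]
        # Decode from the returned IDs so text and token IDs cannot diverge.
        text = tokenizer.decode(token_ids)
        return text, token_ids, log_probs
    # Fallback path: no log probs returned, so re-tokenize the postprocessed text.
    text = postprocess_responses(output["text"])
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    return text, token_ids, None
```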