Skip to content

Commit ef9a2ec

Browse files
authored
[Fix] Fix some bugs in retool example (#1130)
1 parent cecba06 commit ef9a2ec

File tree

1 file changed: 16 additions and 4 deletions

examples/retool/generate_with_retool.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,20 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
230230
tool_call_count = 0 # Track actual tool call rounds
231231

232232
for turn in range(TOOL_CONFIGS["max_turns"]):
233-
# Simple: just send prompt + response
233+
# Check if total length exceeds max context length
234+
total_length = len(prompt_tokens_ids) + len(response_token_ids)
235+
if args.rollout_max_context_len is not None:
236+
max_context_length = args.rollout_max_context_len
237+
else:
238+
max_context_length = args.context_parallel_size * args.max_tokens_per_gpu
239+
if total_length >= max_context_length:
240+
sample.status = Sample.Status.TRUNCATED
241+
break
242+
243+
# Use token IDs instead of text
244+
current_token_ids = prompt_tokens_ids + response_token_ids
234245
payload = {
235-
"text": prompt + response,
246+
"input_ids": current_token_ids,
236247
"sampling_params": sampling_params,
237248
"return_logprob": True, # Request log probabilities for training
238249
}
@@ -265,15 +276,16 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
265276
sample.status = Sample.Status.ABORTED
266277
return sample
267278

268-
cur_response = output["text"]
269-
270279
if "output_token_logprobs" in output["meta_info"]:
271280
cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
281+
cur_response = state.tokenizer.decode(cur_response_token_ids)
272282
cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
273283
if sample.rollout_log_probs is None:
274284
sample.rollout_log_probs = []
275285
sample.rollout_log_probs += cur_log_probs
286+
276287
else:
288+
cur_response = output["text"]
277289
cur_response = postprocess_responses(cur_response)
278290
cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
279291

0 commit comments

Comments (0)