Skip to content

Commit 357c6ad

Browse files
committed
Merge remote-tracking branch 'origin/main' into chunked-lm-head
2 parents c5bfedb + e923a9a commit 357c6ad

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

trl/experimental/async_grpo/async_rollout_worker.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -546,27 +546,28 @@ async def _generate_one(
546546
tool_call_count += n_calls
547547
tool_failure_count += n_failures
548548
completion.extend(tool_messages)
549-
tool_suffix_ids = self._build_messages_suffix_ids(tool_messages)
550-
completion_ids.extend(tool_suffix_ids)
551-
completion_logprobs.extend([0.0] * len(tool_suffix_ids))
552-
tool_mask.extend([0] * len(tool_suffix_ids))
553-
prompt_ids = prompt_ids + turn_ids + tool_suffix_ids
549+
suffix_ids = self._get_tool_suffix_ids(tool_messages)
550+
completion_ids.extend(suffix_ids)
551+
completion_logprobs.extend([0.0] * len(suffix_ids))
552+
tool_mask.extend([0] * len(suffix_ids))
553+
prompt_ids = prompt_ids + turn_ids + suffix_ids
554554
iteration_num += 1
555555

556-
def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int]:
557-
template_messages = [
556+
def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]:
557+
"""Get token IDs for tool result formatting by using a minimal dummy conversation."""
558+
dummy_messages = [
558559
{"role": "user", "content": ""},
559560
{"role": "assistant", "content": ""},
560561
]
561562
prefix_ids = self.tokenizer.apply_chat_template(
562-
template_messages,
563+
dummy_messages,
563564
return_dict=False,
564565
tools=self.tools,
565566
chat_template=self.chat_template,
566567
**self.chat_template_kwargs,
567568
)
568-
prefix_and_messages_ids = self.tokenizer.apply_chat_template(
569-
template_messages + messages,
569+
full_ids = self.tokenizer.apply_chat_template(
570+
dummy_messages + tool_messages,
570571
return_dict=False,
571572
chat_template=self.chat_template,
572573
add_generation_prompt=True,
@@ -575,15 +576,15 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int
575576
)
576577

577578
# Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block.
578-
# When we compute `suffix_ids` by slicing `prefix_and_messages_ids`, we must align the slicing boundary to
579+
# When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
579580
# EOS (not EOS + newline).
580581
last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id)
581582
prefix_ids = prefix_ids[: last_eos_idx + 1]
582583

583-
if prefix_and_messages_ids[: len(prefix_ids)] != prefix_ids:
584+
if full_ids[: len(prefix_ids)] != prefix_ids:
584585
raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.")
585586

586-
return prefix_and_messages_ids[len(prefix_ids) :]
587+
return full_ids[len(prefix_ids) :]
587588

588589
def _execute_tool_calls(
589590
self, tool_calls: list[dict[str, Any]], tool_dict: dict[str, Callable]

0 commit comments

Comments (0)