
Commit 5c67023

fix: Replace decode-based prefix matching with EOS-boundary splicing (#1337)

Signed-off-by: Parth Chadha <pchadha@nvidia.com>

1 parent 15a0343 commit 5c67023

File tree: 3 files changed (+328, -125 lines)


nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 177 additions & 83 deletions
@@ -22,7 +22,6 @@
 import torch
 import uvicorn
 from fastapi import FastAPI
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.distributed.virtual_cluster import _get_free_port_local, _get_node_ip_local
@@ -36,88 +35,90 @@
 from nemo_rl.models.generation.vllm.vllm_worker import BaseVllmGenerationWorker
 
 
-def _maybe_correct_merged_tokens(
-    tokenizer: PreTrainedTokenizerBase,
-    reference_token_ids: list[int],
-    actual_token_ids: list[int],
+def _replace_prefix_tokens(
+    tokenizer,
+    model_prefix_token_ids: list[int],
+    template_prefix_token_ids: list[int],
+    template_token_ids: list[int],
 ) -> list[int]:
-    """This is a subroutine used inside the vLLM Chat Completion server. Some environments (namely Penguin) require an OpenAI compatible server endpoint rather than an inference engine handle. This is fine for the most part, but it may cause issues when the environment is used as a part of training.
-
-    RL training frameworks train models on token IDs, but the OpenAI compatible server communicates in what is basically de-tokenized text. When multiple model calls are made to the OpenAI compatible server in a single trajectory, model generations in previous model calls may be re-tokenized to something that is different than what was generated. This is not too big of an issue (that we know of) at inference time, but the log probs the model produces are different enough for the differently re-tokenized generation result that it causes the training to be off policy. Off policy isn't necessarily a bad thing in isolation, but this source of off-policyness may cause unexpected issues if not properly accounted for. It also mis-aligns the token ID sequences across model calls, which feels very strange during training.
-
-    Thus, in this function we attempt to correct any minor re-tokenization errors in an effort to stay on-policy as possible. We require the tokenizer, the ground truth reference token ids taken directly from previous model calls, and the re-tokenized actual token ids.
-
-    In other words, for the current model call:
-    - reference_token_ids = all_prefill_so_far + new_generation
-        - all_prefill_so_far: the last model call model engine input token ids. Literally what the model sees during the last generation call.
-        - new_generation: the last model call model engine generated token ids. Literally what the model generates during the last generation call.
-    - actual_token_ids = all_prefill_so_far_maybe_diff_tokenization + new_generation_maybe_diff_tokenization + tool_response_or_user + assistant_generation_prompt
-        - all_prefill_so_far_maybe_diff_tokenization: the re-tokenized version of all_prefill_so_far. Since the token IDs in all_prefill_so_far were de-tokenized and returned as OpenAI schema, they must be re-tokenized for the current model call, which means that it may differ from all_prefill_so_far
-        - new_generation_maybe_diff_tokenization: analogous version of all_prefill_so_far_maybe_diff_tokenization for new_generation
-        - tool_response_or_user: some returned user or tool message. It doesn't matter that this is tokenized here since it has never been tokenized before. However, at the next model call, this will become part of the all_prefill_so_far.
-        - assistant_generation_prompt: a common sequence of tokens to instruct the model to generate an assistant response.
-
-    The goal of this subroutine is to find the prefix in actual_token_ids that corresponds to the de-tokenized text of reference_token_ids.
-    The idea of this subroutine implementation is to just de-tokenize subsequences of actual_token_ids (called candidate_token_ids) until the de-tokenized text matches the de-tokenized text of reference_token_ids.
-
-    TODO When NeMo RL supports training image generation models, we want to revisit and possibly update this function. This issue occurs when the model generates tokens that are de-tokenized into text or images, and then re-tokenized into tokens. So if there is a situation like that with images and image tokenization is non-unique, then we will need to uppdate this function.
+    """This is a subroutine used inside the vLLM Chat Completion server.
+
+    This function is for fixing up the chat template-tokenized messages history
+    to match the model output tokenization up to the last assistant turn,
+    in order to preserve the monotonic tokens property for optimized multi-turn
+    training.
+
+    Some environments (namely Penguin) require an OpenAI compatible server
+    endpoint rather than an inference engine handle. This is fine for the most
+    part, but it may cause issues when the environment is used as a part of
+    training.
+
+    RL training frameworks train models on token IDs, but the OpenAI compatible
+    server communicates in what is basically de-tokenized text. When multiple
+    model calls are made to the OpenAI compatible server in a single trajectory,
+    model generations in previous model calls may be re-tokenized to something
+    that is different than what was generated. This is not too big of an issue
+    (that we know of) at inference time, but the log probs the model produces
+    are different enough for the differently re-tokenized generation result that
+    it causes the training to be off policy. Off policy isn't necessarily a bad
+    thing in isolation, but this source of off-policyness may cause unexpected
+    issues if not properly accounted for. It also mis-aligns the token ID
+    sequences across model calls, which feels very strange during training.
+
+    There are real cases where the model output string _does not match_ the chat
+    template tokenization of the parsed model output. A concrete example is
+    inconsistent whitespace tokens around tool call special tokens.
+
+    TODO When NeMo RL supports training image generation models, we want to
+    revisit and possibly update this function. This issue occurs when the model
+    generates tokens that are de-tokenized into text or images, and then
+    re-tokenized into tokens. So if there is a situation like that with images
+    and image tokenization is non-unique, then we will need to uppdate this
+    function.
+
+    Example (turn-by-turn, concise; eos_token_id = 2):
+      Turn 1:
+        - prefill_T1 (template prefill) = [11,12,13,40,41]
+        - model output = [220,17,2]  # decodes to " 4" + EOS
+        - model_prefix_token_ids = prefill_T1 + model output
+          => [11,12,13,40,41,220,17,2]
+
+      Turn 2 (template retokenizes prior assistant text differently):
+        - template_prefix_token_ids = [11,12,13,40,41,1001,2]  # 1001 decodes to " 4"
+        - template_token_ids = [11,12,13,40,41,1001,2,21,22,40,41]
+
+      _replace_prefix_tokens keeps the exact prior model tokens up to EOS and
+      resumes from the template after that EOS:
+        output => [11,12,13,40,41,220,17,2,21,22,40,41]
     """
-    if not reference_token_ids:
-        return actual_token_ids
-
-    # No re-tokenization errors
-    if reference_token_ids == actual_token_ids[: len(reference_token_ids)]:
-        return actual_token_ids
-
-    reference_str, actual_str = tokenizer.batch_decode(
-        [reference_token_ids, actual_token_ids]
+    if not model_prefix_token_ids:
+        return template_token_ids
+
+    eos_token_id = tokenizer.eos_token_id
+    assert eos_token_id is not None, "Your tokenizer must have an EOS token ID!"
+
+    model_cut_end = len(model_prefix_token_ids)
+    if model_prefix_token_ids:
+        # We are not always guaranteed that the model outputs an EOS token as the stop criteria of the previous model call e.g. when the model reaches max_tokens.
+        # And since chat templates will always add one for us, we just cut the model input to right before the EOS token ID (if applicable)
+        if model_prefix_token_ids[-1] == eos_token_id:
+            model_cut_end -= 1
+
+    # We take everything starting with the EOS token ID.
+    template_cut_start = -1
+    for pos in reversed(range(len(template_prefix_token_ids))):
+        if template_token_ids[pos] == eos_token_id:
+            template_cut_start = pos
+            break
+
+    # This should never be the case, but
+    assert template_cut_start >= 0, (
+        "No EOS token ID found in the chat-templated messages!"
     )
 
-    # For now, if a trajectory is not monotonically increasing, we assert.
-    # Eventually when we support non-monotonic training, we need to update this logic
-    assert (
-        reference_str == actual_str[: len(reference_str)]
-    ), f"""Found a non-monotonically increasing trajectory that is not caused by a token merge on re-tokenization!
-    Reference str: {reference_str}
-    Actual str: {actual_str}
-
-    Reference token ids: {reference_token_ids}
-    Actual token ids: {actual_token_ids}"""
-
-    # Now we want to try to find the subsequence of actual_token_ids that corresponds to reference_str
-    # Our first guess is just the prefix in actual_token_ids of length reference_token_ids. How good of a guess this is depends on the distribution of the number of re-tokenization errors.
-    # If there are a lot, this will be a poor guess. If there aren't that many this is a good guess.
-    candidate_token_ids = actual_token_ids[: len(reference_token_ids)]
-    candidate_str = tokenizer.decode(candidate_token_ids)
-
-    # If it's longer, we remove
-    if len(candidate_str) > len(reference_str):
-        while (
-            candidate_str != reference_str
-            and len(candidate_str) > len(reference_str)
-            and candidate_token_ids
-        ):
-            candidate_token_ids.pop()
-            candidate_str = tokenizer.decode(candidate_token_ids)
-    # If it's shorter we append
-    elif len(candidate_str) < len(reference_str):
-        while (
-            candidate_str != reference_str
-            and len(candidate_str) < len(reference_str)
-            and len(candidate_token_ids) < len(actual_token_ids) - 1
-        ):
-            candidate_token_ids.append(actual_token_ids[len(candidate_token_ids)])
-            candidate_str = tokenizer.decode(candidate_token_ids)
-    # If it's equal we should not need to do any modification. The assert below will directly error out.
-    else:
-        pass
-
-    # If we break above, it must be that we either found a correct match or that we didn't find a valid match
-    # e.g. in cases where there is some token merging that occurs at the very end of the reference_token_ids
-    # We scream loudly here.
-    assert candidate_str == reference_str
-
-    return reference_token_ids + actual_token_ids[len(candidate_token_ids) :]
+    return (
+        model_prefix_token_ids[:model_cut_end] + template_token_ids[template_cut_start:]
+    )
 
 
 @ray.remote(
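
For reference, here is a minimal standalone sketch of the splice applied to the worked example in the docstring above. The StubTokenizer and the splice helper are illustrative re-creations for this note, not imports from the module; in the server the real Hugging Face tokenizer supplies eos_token_id.

class StubTokenizer:
    # Hypothetical stand-in; only eos_token_id is needed for the splice.
    eos_token_id = 2

def splice(tokenizer, model_prefix_token_ids, template_prefix_token_ids, template_token_ids):
    # Drop the model's trailing EOS (if present); the chat template re-adds one.
    cut_end = len(model_prefix_token_ids)
    if model_prefix_token_ids and model_prefix_token_ids[-1] == tokenizer.eos_token_id:
        cut_end -= 1
    # Locate the EOS that closes the last assistant turn in the templated prompt.
    cut_start = next(
        pos
        for pos in reversed(range(len(template_prefix_token_ids)))
        if template_token_ids[pos] == tokenizer.eos_token_id
    )
    # Keep the model's exact prior tokens and resume from the template at that EOS.
    return model_prefix_token_ids[:cut_end] + template_token_ids[cut_start:]

model_prefix = [11, 12, 13, 40, 41, 220, 17, 2]           # prefill_T1 + " 4" + EOS
template_prefix = [11, 12, 13, 40, 41, 1001, 2]           # template re-tokenized " 4" as 1001
template = [11, 12, 13, 40, 41, 1001, 2, 21, 22, 40, 41]  # full turn-2 prompt

assert splice(StubTokenizer(), model_prefix, template_prefix, template) == [
    11, 12, 13, 40, 41, 220, 17, 2, 21, 22, 40, 41,
]

Unlike the old decode-and-compare loop, the splice never re-decodes token IDs, so the tokens the engine actually saw and produced in the previous call are preserved exactly.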
@@ -151,6 +152,9 @@ async def report_dp_openai_server_base_url(self) -> Optional[str]:
         return self.base_url
 
     def _setup_vllm_openai_api_server(self, app: FastAPI) -> FastAPI:
+        from copy import deepcopy
+        from logging import Filter as LoggingFilter
+        from logging import LogRecord
         from typing import List, Optional, Union
 
         from fastapi import Request
@@ -169,6 +173,7 @@ def _setup_vllm_openai_api_server(self, app: FastAPI) -> FastAPI:
             TokenizeCompletionRequest,
             TokenizeResponse,
         )
+        from vllm.v1.engine.async_llm import logger as vllm_async_llm_logger
 
         engine_client = self.llm
         model_config = self.llm_async_engine_args.create_model_config()
@@ -214,6 +219,14 @@ async def _preprocess_chat(
                 truncate_prompt_tokens=None,
                 add_special_tokens=False,
             ):
+                # Materialize the message tool calls so we can deepcopy below.
+                for message in messages:
+                    if message.get("tool_calls"):
+                        message["tool_calls"] = list(message["tool_calls"])
+
+                # Deepcopy messages here since _preprocess_chat may be destructive.
+                messages_for_replace_prefix_tokens = deepcopy(messages)
+
                 # res is conversation, [request_prompt], [engine_prompt]
                 res = await super()._preprocess_chat(
                     request,
@@ -234,14 +247,50 @@ async def _preprocess_chat(
                 if request.required_prefix_token_ids is None:
                     return res
 
+                # Find the last assistant message
+                last_assistant_message_idx = None
+                for i in reversed(range(len(messages_for_replace_prefix_tokens))):
+                    if messages_for_replace_prefix_tokens[i]["role"] == "assistant":
+                        last_assistant_message_idx = i
+                        break
+
+                # If there's no assistant message, we don't have any issues.
+                if last_assistant_message_idx is None:
+                    return res
+
+                # Include the last assistant message itself.
+                messages_to_last_assistant_message = messages_for_replace_prefix_tokens[
+                    : last_assistant_message_idx + 1
+                ]
+                # Call the actual preprocess chat subroutine so we don't miss anything. Whatever they do is whatever we do since we literally do what they do.
+                corresponding_res = await super()._preprocess_chat(
+                    request,
+                    tokenizer,
+                    messages_to_last_assistant_message,
+                    chat_template,
+                    chat_template_content_format,
+                    add_generation_prompt=False,
+                    continue_final_message=False,
+                    tool_dicts=tool_dicts,
+                    documents=documents,
+                    chat_template_kwargs=chat_template_kwargs,
+                    tool_parser=tool_parser,
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=add_special_tokens,
+                )
+                actual_corresponding_token_ids = corresponding_res[2][0][
+                    "prompt_token_ids"
+                ]
+
                 engine_prompt = res[2][
                     0
                 ]  # We need to modify engine_prompt.prompt_token_ids
 
-                final_prompt_token_ids = _maybe_correct_merged_tokens(
+                final_prompt_token_ids = _replace_prefix_tokens(
                     tokenizer=tokenizer,
-                    reference_token_ids=request.required_prefix_token_ids,
-                    actual_token_ids=engine_prompt["prompt_token_ids"],
+                    model_prefix_token_ids=request.required_prefix_token_ids,
+                    template_prefix_token_ids=request.required_prefix_token_ids,
+                    template_token_ids=engine_prompt["prompt_token_ids"],
                 )
 
                 engine_prompt["prompt_token_ids"] = final_prompt_token_ids
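
Taken on its own, the message slicing added above reduces to the following sketch. The messages list here is a hypothetical OpenAI-style history, purely for illustration:

messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": " 4"},
    {"role": "user", "content": "And times 3?"},
]

# Walk backwards to find the last assistant turn, then keep everything up to
# and including it; that prefix is the part of the history whose tokenization
# must line up with the prior model output.
last_assistant_message_idx = None
for i in reversed(range(len(messages))):
    if messages[i]["role"] == "assistant":
        last_assistant_message_idx = i
        break

if last_assistant_message_idx is not None:
    messages_to_last_assistant_message = messages[: last_assistant_message_idx + 1]
    assert [m["role"] for m in messages_to_last_assistant_message] == ["user", "assistant"]

The truncated history is then run through the same super()._preprocess_chat call (with add_generation_prompt=False and continue_final_message=False) to obtain its chat-templated token IDs.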
@@ -330,7 +379,38 @@ class NeMoRLTokenizeChatRequest(
         class NeMoRLOpenAIServingTokenization(
             NeMoRLOpenAIServingMixin, OpenAIServingTokenization
         ):
-            pass
+            async def create_tokenize(self, request, raw_request):
+                """Override to handle required_prefix_token_ids for tokenization."""
+                # Call parent's create_tokenize first
+                result = await super().create_tokenize(request, raw_request)
+
+                # If there's an error or no required_prefix_token_ids, return as-is
+                if isinstance(result, ErrorResponse):
+                    return result
+
+                # Only process chat requests (not completion requests)
+                if not hasattr(request, "messages"):
+                    return result
+
+                # Get the template-tokenized tokens from the result
+                template_token_ids = result.tokens
+
+                # Get the tokenizer from the engine client
+                tokenizer = await self.engine_client.get_tokenizer()
+
+                # Apply _replace_prefix_tokens to fix up the tokenization
+                final_token_ids = _replace_prefix_tokens(
+                    tokenizer=tokenizer,
+                    model_prefix_token_ids=request.required_prefix_token_ids,
+                    template_prefix_token_ids=request.required_prefix_token_ids,
+                    template_token_ids=template_token_ids,
+                )
+
+                # Update the result with the corrected tokens
+                result.tokens = final_token_ids
+                result.count = len(final_token_ids)
+
+                return result
 
         openai_serving_tokenization = NeMoRLOpenAIServingTokenization(
             engine_client,
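
To picture the effect on the tokenize endpoint, a client-side sketch might look like the following. The endpoint path, port, and model name are assumptions (vLLM's OpenAI-compatible server normally serves /tokenize); required_prefix_token_ids is the field this server adds on top of the usual chat tokenize request.

import requests  # assumed available on the client side

payload = {
    "model": "my-model",  # placeholder model name
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": " 4"},
        {"role": "user", "content": "And times 3?"},
    ],
    # Token IDs the engine actually saw and produced in the previous call.
    "required_prefix_token_ids": [11, 12, 13, 40, 41, 220, 17, 2],
}
resp = requests.post("http://localhost:8000/tokenize", json=payload, timeout=30)
body = resp.json()
# With the create_tokenize override above, body["tokens"] begins with the exact
# prior model tokens and body["count"] == len(body["tokens"]).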
@@ -356,6 +436,20 @@ async def tokenize(request: NeMoRLTokenizeRequest, raw_request: Request):
             elif isinstance(generator, TokenizeResponse):
                 return JSONResponse(content=generator.model_dump())
 
+        ########################################
+        # Logging
+        ########################################
+        print(
+            "Adding a vLLM logging filter so that the logs aren't spammed with `Added request ...` messages. This is to help errors pop up better and filter out noise."
+        )
+
+        class NoAddedRequestFilter(LoggingFilter):
+            def filter(self, record: LogRecord) -> bool:
+                msg = record.getMessage()
+                return "Added request" not in msg
+
+        vllm_async_llm_logger.addFilter(NoAddedRequestFilter())
+
         return app
 
     def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":
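
As a standalone illustration of the filter behavior, the same idea against Python's stdlib logging (a demo logger rather than the vLLM logger object):

import logging

class NoAddedRequestFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False drops the record; keep everything except the
        # per-request "Added request ..." lines.
        return "Added request" not in record.getMessage()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")
logger.addFilter(NoAddedRequestFilter())

logger.info("Added request chatcmpl-abc123.")  # filtered out
logger.info("Engine startup complete.")        # still emitted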
