SFT distillation: bug fixes, VLM support, and pretokenization optimization#2053
Conversation
…d CustomLossConfig Co-authored-by: will brown <willccbb@users.noreply.github.com>
# Conflicts:
#   CHANGELOG.md
#   src/prime_rl/orchestrator/orchestrator.py
#   src/prime_rl/orchestrator/scheduler.py
#   src/prime_rl/orchestrator/trajectories.py
#   src/prime_rl/trainer/rl/train.py
#   tests/unit/test_configs.py
full_ids was tokenized via _render_messages then immediately overwritten by build_incremental_token_mask, wasting a tokenization pass per step. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Duplicate of should_add_generation_prompt in utils/chat_template.py, never called anywhere. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Explicitly disable top-k and min-p filtering on all vLLM requests, not just the token client path. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…inference Redundant with the Pydantic model validator definition. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When external teacher rollouts have tool_defs, convert them to OAI format and pass through to build_incremental_token_mask so that tokenization includes tool definition tokens from the chat template. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Matches the SFT data pipeline behavior. Without this, parallel tool call responses cause an assertion failure in incremental tokenization when the chat template re-renders consecutive tool messages. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Tool definitions arrive as plain dicts after ZMQ msgpack serialization (verifiers' msgpack_encoder calls model_dump() on Tool pydantic objects). Use duck-typed accessor to support both dict and object forms. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
For VLM models, the tokenizer alone produces only 1 <|image_pad|> token per image, but the model expects the count to match actual image patches. Thread the processor through pretokenize → render_messages → build_incremental_token_mask so that processor.apply_chat_template (which expands image placeholders correctly) is used when available. Also move pretokenize before build_vlm_image_cache in the orchestrator, since the cache build strips image data from messages. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
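A hedged sketch of the dispatch this commit describes: prefer the processor's chat template when one is available, falling back to the tokenizer for non-VLM models. The real `render_messages` in `utils/chat_template.py` likely differs in details; this only illustrates the selection logic.

```python
def render_messages(tokenizer, messages, add_generation_prompt=False,
                    tools=None, processor=None):
    """Tokenize a message list via the chat template.

    When a multimodal processor is provided, use its apply_chat_template,
    which expands image placeholders to match actual image patch counts;
    otherwise (processor is None) the tokenizer-only path runs unchanged.
    """
    template_owner = processor if processor is not None else tokenizer
    return template_owner.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=add_generation_prompt,
        tools=tools,
    )
```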
In the SFT distillation path, completion is always a single assistant message (enforced by verifiers' parse_response_message), so the completion mask is trivially all-True. Replace N incremental tokenizer/processor calls from build_incremental_token_mask with a single render_messages call. For VLM models this avoids redundant image preprocessing on every incremental prefix. Added assertion to guard the single-assistant-role assumption. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Autofix Details
Bugbot Autofix prepared a fix for the issue found in the latest run.
- ✅ Fixed: Five trivial wrapper functions add unnecessary indirection
- Removed the five pass-through wrappers in trajectories and updated internal/test call sites to use the imported chat_template functions directly.
Preview (50cd9ca098)
diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py
--- a/src/prime_rl/orchestrator/trajectories.py
+++ b/src/prime_rl/orchestrator/trajectories.py
@@ -45,39 +45,6 @@
zero_entry = [[0] * topk for _ in range(num_layers)]
return routed_experts + [zero_entry for _ in range(deficit)]
-
-def _common_prefix_len(a: list[int], b: list[int]) -> int:
- return common_prefix_len(a, b)
-
-
-def _normalize_messages(messages: Any, default_role: str) -> list[dict[str, Any]]:
- return normalize_messages(messages, default_role)
-
-
-def _deserialize_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
- return deserialize_tool_calls(messages)
-
-
-def _strip_message_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
- return strip_message_content(messages)
-
-
-def _render_messages(
- tokenizer: PreTrainedTokenizer,
- messages: list[dict[str, Any]],
- add_generation_prompt: bool = False,
- tools: list[dict[str, Any]] | None = None,
- processor=None,
-) -> list[int]:
- return render_messages(
- tokenizer,
- messages,
- add_generation_prompt=add_generation_prompt,
- tools=tools,
- processor=processor,
- )
-
-
def _prepare_messages_for_processor(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert messages to the format expected by the VLM processor.
@@ -124,11 +91,11 @@
tools: list[dict[str, Any]] | None = None,
processor=None,
) -> dict[str, Any]:
- prompt = _normalize_messages(step.get("prompt"), default_role="user")
- completion = _normalize_messages(step.get("completion"), default_role="assistant")
+ prompt = normalize_messages(step.get("prompt"), default_role="user")
+ completion = normalize_messages(step.get("completion"), default_role="assistant")
- prompt = _strip_message_content(_deserialize_tool_calls(prompt))
- completion = _strip_message_content(_deserialize_tool_calls(completion))
+ prompt = strip_message_content(deserialize_tool_calls(prompt))
+ completion = strip_message_content(deserialize_tool_calls(completion))
assert all(m.get("role") == "assistant" for m in completion), (
"Expected all completion messages to be assistant role for SFT distillation, "
@@ -141,21 +108,21 @@
all_messages = prompt + completion
prompt_has_assistant_completion = len(completion) > 0 and completion[0].get("role") == "assistant"
- prompt_ids = _render_messages(
+ prompt_ids = render_messages(
tokenizer,
prompt,
add_generation_prompt=prompt_has_assistant_completion,
tools=tools,
processor=processor,
)
- full_ids = _render_messages(
+ full_ids = render_messages(
tokenizer,
all_messages,
tools=tools,
processor=processor,
)
- split_idx = _common_prefix_len(prompt_ids, full_ids)
+ split_idx = common_prefix_len(prompt_ids, full_ids)
completion_ids = full_ids[split_idx:]
completion_mask = [True] * len(completion_ids)
diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py
--- a/tests/unit/orchestrator/test_trajectories.py
+++ b/tests/unit/orchestrator/test_trajectories.py
@@ -10,13 +10,13 @@
from prime_rl.orchestrator.trajectories import (
VLMImageCache,
_align_routed_experts,
- _deserialize_tool_calls,
_extract_images_from_examples,
_extract_images_from_messages,
_ImageStore,
build_vlm_image_cache,
interleave_rollout,
)
+from prime_rl.utils.chat_template import deserialize_tool_calls
def _pixels(data: list[list[float]]) -> tuple[bytes, list[int]]:
@@ -33,7 +33,7 @@
def test_deserialize_tool_calls_does_not_inject_missing_key():
messages = [{"role": "assistant", "content": "hello"}]
- deserialized = _deserialize_tool_calls(messages)
+ deserialized = deserialize_tool_calls(messages)
assert "tool_calls" not in deserialized[0]
@@ -52,7 +52,7 @@
}
]
- deserialized = _deserialize_tool_calls(messages)
+ deserialized = deserialize_tool_calls(messages)
    assert deserialized[0]["tool_calls"][0]["function"]["arguments"] == {"x": 1}
        add_generation_prompt=add_generation_prompt,
        tools=tools,
        processor=processor,
    )
Five trivial wrapper functions add unnecessary indirection
Low Severity
_common_prefix_len, _normalize_messages, _deserialize_tool_calls, _strip_message_content, and _render_messages are single-line wrappers that purely delegate to identically-named functions already imported from prime_rl.utils.chat_template. The callers inside this file (and the test that imports _deserialize_tool_calls) could use the imported functions directly, removing five layers of unnecessary indirection.
samsja
left a comment
lets remove all the custom config please. maybe keep it one
    if use_token_client:
        sampling_args["logprobs"] = True
        extra_body["return_token_ids"] = True

    if extra_body:
        sampling_args["extra_body"] = extra_body
we need logprob in both use token client true and false no ?
added it but we need to make sure passing logprob = True doesn't break an api request to non-vllm api? and same with unconditionally passing in top_k and min_p
Keep examples/alphabet_sort/sft_distill_hard.toml as the canonical example config. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
logprobs=True was only set when use_token_client=True, but non-token-client paths (e.g. external APIs) also need logprobs for RL training. return_token_ids remains gated behind use_token_client since it is vLLM-specific. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
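The resulting shape of the sampling-args construction, sketched from the diff in the review thread above. The function name is hypothetical, and the exact "off" values for top-k and min-p (`-1` and `0.0` follow vLLM conventions) are assumptions:

```python
def build_sampling_args(use_token_client: bool) -> dict:
    # logprobs are needed on every path for RL training,
    # not only the token-client path
    sampling_args = {"logprobs": True}
    # explicitly disable top-k and min-p filtering on all requests
    extra_body = {"top_k": -1, "min_p": 0.0}
    if use_token_client:
        # vLLM-specific: return raw token ids with the response
        extra_body["return_token_ids"] = True
    sampling_args["extra_body"] = extra_body
    return sampling_args
```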
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When enable_policy_updates is False (SFT distillation), generate_batch skipped calling checkpoint_ready.set(). This worked by accident because the event starts set and is only cleared inside maybe_update_policy (which is never called in this path), but would deadlock if anything else ever cleared the event. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
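A toy reproduction of the latent deadlock (names mirror the commit message but the surrounding structure is illustrative): the event starts set, so skipping `set()` works until some other code path clears it.

```python
import asyncio


async def demo() -> str:
    checkpoint_ready = asyncio.Event()
    checkpoint_ready.set()  # starts set, as in the orchestrator

    async def generate_batch(enable_policy_updates: bool) -> str:
        if not enable_policy_updates:
            # the fix: set unconditionally in the SFT distillation path,
            # instead of relying on the event never being cleared
            checkpoint_ready.set()
        await asyncio.wait_for(checkpoint_ready.wait(), timeout=1)
        return "ok"

    # simulate something else clearing the event; without the fix,
    # the wait above would time out (the deadlock)
    checkpoint_ready.clear()
    return await generate_batch(enable_policy_updates=False)


result = asyncio.run(demo())
```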
mikasenghaas
left a comment
nice looks p clean overall
are the changes here related to sft distillation on rl trainer?
i think this came from @willccbb but seems to just be a refactor of helper functions that were duplicated between sft/data.py and orchestrator/trajectories.py into a new shared module utils/chat_template.py - should this be a separate pr?
from prime_rl.utils.vlm import is_vlm_model


def setup_external_rollout_model(config: OrchestratorConfig, logger) -> tuple[Any, str, bool]:
would prefer not starting to have orch utils here
- Revert loss_scale regression: use loss_mask-based scaling unconditionally - Move setup_external_rollout_model to orchestrator/utils.py - Clarify log message to say "SFT distillation mode" Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…sequences

When tokenizing trajectory steps from messages, the prompt-only and full (prompt+completion) tokenizations can diverge at a boundary (non-prefix split). Previously, the full prompt_ids was returned even when split_idx < len(prompt_ids), producing corrupted sequences where prompt_ids + completion_ids != full_ids. This broke interleaving (extension property checks fail on mismatched prefixes) and fed invalid token sequences to the trainer. Fix: use full_ids[:split_idx] as prompt_ids so the concatenation always equals full_ids exactly. Verified across GLM-4, Qwen3, Qwen2.5, and Qwen2.5-VL tokenizers, and tested end-to-end SFT distillation on alphabet-sort, wiki-search, tic-tac-toe, and color-codeword (VLM) environments. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
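The fix can be illustrated with a small sketch. `common_prefix_len` follows the helper named in the diff elsewhere in this PR; `split_prompt_completion` itself is a hypothetical wrapper for illustration:

```python
def common_prefix_len(a: list[int], b: list[int]) -> int:
    """Length of the longest common prefix of two token-id lists."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def split_prompt_completion(
    prompt_ids: list[int], full_ids: list[int]
) -> tuple[list[int], list[int]]:
    split_idx = common_prefix_len(prompt_ids, full_ids)
    # Take the prompt from full_ids, not the original prompt_ids, so
    # prompt + completion always reconstructs full_ids exactly, even
    # when the two tokenizations diverge at the boundary.
    return full_ids[:split_idx], full_ids[split_idx:]
```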
Cursor Bugbot has reviewed your changes and found 1 potential issue.
There are 2 total unresolved issues (including 1 from previous review).
…ation

The previous commit, which truncated prompt_ids to full_ids[:split_idx], made the debug log condition always false (split_idx < split_idx). Track the original prompt length before the overwrite so the log fires correctly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>




Summary
Builds on top of #1905 (SFT distillation with teacher endpoint). This PR adds bug fixes and VLM multimodal support discovered during end-to-end testing of the SFT distillation pipeline across multiple environments.
Tested configurations:
Bug fixes
Handle dict `tool_defs` in `_convert_tools_to_oai_format` — Tool definitions arrive as plain dicts after ZMQ msgpack serialization (verifiers' `msgpack_encoder` calls `model_dump()` on Tool pydantic objects). The function assumed attribute access (`.name`); it now supports both dict and object forms.

Clean spurious fields from VLM messages — Some environments (e.g., tic-tac-toe) include `image_url: None` on text content items, which causes the Qwen3-VL processor to miscount images vs. image placeholder tokens. `_prepare_messages_for_processor` now emits clean `{"type": "text", "text": ...}` dicts.

VLM processor support in pretokenization

For VLM models, the tokenizer's `apply_chat_template` produces only 1 `<|image_pad|>` token per image, but the model expects the count to match the actual image patches (e.g., 1900 for a larger image). The processor correctly expands these placeholders but wasn't wired into the pretokenization path.

- Thread the processor through `pretokenize_rollout_trajectory` → `_tokenize_step_from_messages` → `render_messages` / `build_incremental_token_mask`
- Add `_prepare_messages_for_processor` to convert `image_url` items to PIL Images and normalize the message format for the processor
- Move `pretokenize_rollout_trajectory` before `build_vlm_image_cache` in the orchestrator (the cache build strips image data from messages)
- When the processor is `None` (non-VLM models), the original tokenizer-only path runs unchanged

Pretokenization optimization
In the SFT distillation path, each trajectory step's completion is always a single assistant message (structurally enforced by verifiers' `parse_response_message`). The previous code used `build_incremental_token_mask` (N tokenizer/processor calls per step) to compute a mask that is trivially `[True] * len(completion_ids)`. Replaced with a single `render_messages` call plus direct mask assignment. For VLM models, this avoids redundant image preprocessing on every incremental prefix. Added an assertion to guard the single-assistant-role assumption.

Test plan

(`test_sft_trajectories.py`)

🤖 Generated with Claude Code
Note
Medium Risk
Changes orchestrator rollouts, sampling args, and trainer loss selection, including disabling weight broadcast/policy updates in a new external-rollout SFT mode; mistakes could break training/inference behavior or tokenization, especially for multimodal/tool-call trajectories.
Overview
Enables hard SFT distillation from an external OpenAI-compatible teacher endpoint via a new `orchestrator.teacher_rollout_model`, with config validation that enforces `trainer.loss.type = "sft"`, `use_token_client = false`, and no local `[inference]` section.

Updates the orchestrator to route rollout generation through the external model when configured, disabling policy weight updates/weight broadcast and adjusting checkpoint-step handling accordingly.
Adds an `sft` loss variant (`SFTLossConfig` plus a masked-NLL implementation) and introduces a shared `chat_template` utility plus rollout pretokenization/reconstruction (including VLM image message handling and tool-call argument deserialization), with new unit tests and example/docs/config updates.

Written by Cursor Bugbot for commit c95d9e1.