SFT distillation: bug fixes, VLM support, and pretokenization optimization#2053

Open

eligotts wants to merge 26 commits into main from eli/sft-distillation

Conversation

@eligotts commented Mar 19, 2026

Summary

Builds on top of #1905 (SFT distillation with teacher endpoint). This PR adds bug fixes and VLM multimodal support discovered during end-to-end testing of the SFT distillation pipeline across multiple environments.

Tested configurations:

  • Wiki-search — text-only, multi-turn tool calling, Claude Sonnet 4.6 teacher → Qwen3-4B student
  • Tic-tac-toe — VLM multimodal, multi-turn tool calling with board images, Claude teacher → Qwen3-VL-4B student
  • Color-codeword — VLM multimodal, multi-turn with dynamically generated images, no tools, Claude teacher → Qwen3-VL-4B student

Bug fixes

  • Handle dict tool_defs in _convert_tools_to_oai_format — Tool definitions arrive as plain dicts after ZMQ msgpack serialization (verifiers' msgpack_encoder calls model_dump() on Tool pydantic objects). The function previously assumed attribute access (.name); it now supports both dict and object forms.

  • Clean spurious fields from VLM messages — Some environments (e.g., tic-tac-toe) include image_url: None on text content items, which causes the Qwen3-VL processor to miscount images vs. image placeholder tokens. _prepare_messages_for_processor now emits clean {"type": "text", "text": ...} dicts.
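The normalization can be sketched like this (function name illustrative; the real code lives in `_prepare_messages_for_processor`). Text items are re-emitted as bare dicts so stray keys such as `image_url: None` cannot skew the processor's image/placeholder counting, while non-text items pass through unchanged:

```python
from typing import Any


def clean_text_items(content: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Re-emit text content items as clean {"type": "text", "text": ...}
    dicts, dropping spurious keys; leave other item types untouched."""
    cleaned: list[dict[str, Any]] = []
    for item in content:
        if item.get("type") == "text":
            cleaned.append({"type": "text", "text": item.get("text", "")})
        else:
            cleaned.append(item)
    return cleaned
```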

VLM processor support in pretokenization

For VLM models, the tokenizer's apply_chat_template produces only 1 <|image_pad|> token per image, but the model expects the count to match actual image patches (e.g., 1900 for a larger image). The processor correctly expands these placeholders but wasn't wired into the pretokenization path.

  • Thread the processor through pretokenize_rollout_trajectory → _tokenize_step_from_messages → render_messages / build_incremental_token_mask
  • Add _prepare_messages_for_processor to convert image_url items to PIL Images and normalize message format for the processor
  • Move pretokenize_rollout_trajectory before build_vlm_image_cache in the orchestrator (the cache build strips image data from messages)
  • When processor is None (non-VLM models), the original tokenizer-only path runs unchanged
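The dispatch in the bullet list above can be sketched with toy stand-ins (all names here are illustrative; the real code threads an optional Hugging Face processor through render_messages). The toy tokenizer reproduces the undercount bug — one placeholder token per image — while the toy processor expands each image to its patch count:

```python
def render_ids(tokenizer, messages, processor=None):
    """Prefer the processor's chat template when available (it expands
    image placeholders to the true patch count for VLM models);
    otherwise fall back to the unchanged tokenizer-only path."""
    template_owner = processor if processor is not None else tokenizer
    return template_owner.apply_chat_template(messages, tokenize=True)


class ToyTokenizer:
    """Emits a single pad token (0) per image — the VLM undercount bug."""
    def apply_chat_template(self, messages, tokenize=True):
        return [0 if part == "<image>" else 1
                for m in messages for part in m["content"]]


class ToyProcessor:
    """Expands each image placeholder to a fixed patch count."""
    PATCHES = 4

    def apply_chat_template(self, messages, tokenize=True):
        ids = []
        for m in messages:
            for part in m["content"]:
                ids.extend([0] * self.PATCHES if part == "<image>" else [1])
        return ids


msgs = [{"role": "user", "content": ["<image>", "describe"]}]
assert render_ids(ToyTokenizer(), msgs) == [0, 1]              # 1 pad token
assert render_ids(ToyTokenizer(), msgs, ToyProcessor()) == [0, 0, 0, 0, 1]
```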

Pretokenization optimization

In the SFT distillation path, each trajectory step's completion is always a single assistant message (structurally enforced by verifiers' parse_response_message). The previous code used build_incremental_token_mask (N tokenizer/processor calls per step) to compute a mask that is trivially [True] * len(completion_ids). Replaced with a single render_messages call + direct mask assignment. For VLM models, this avoids redundant image preprocessing on every incremental prefix. Added an assertion to guard the single-assistant-role assumption.
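Under the single-assistant-completion assumption, the replacement can be sketched as below (toy render function; the real code renders chat templates via render_messages and the names here are illustrative):

```python
def tokenize_step_fast(render, prompt, completion):
    """One render of the prompt and one of prompt+completion replace the
    N incremental calls of build_incremental_token_mask; the completion
    mask is trivially all-True."""
    assert all(m["role"] == "assistant" for m in completion), (
        "fast path requires a single-assistant completion"
    )
    prompt_ids = render(prompt)
    full_ids = render(prompt + completion)
    # Longest common prefix marks the prompt/completion token boundary.
    split = 0
    while (split < min(len(prompt_ids), len(full_ids))
           and prompt_ids[split] == full_ids[split]):
        split += 1
    completion_ids = full_ids[split:]
    return full_ids[:split], completion_ids, [True] * len(completion_ids)


def toy_render(messages):
    """Toy stand-in for render_messages: concatenates integer 'tokens'."""
    return [t for m in messages for t in m["content"]]


prompt = [{"role": "user", "content": [1, 2, 3]}]
completion = [{"role": "assistant", "content": [7, 8]}]
p, c, mask = tokenize_step_fast(toy_render, prompt, completion)
assert p + c == toy_render(prompt + completion)
assert mask == [True, True]
```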

Test plan

  • Unit tests pass (test_sft_trajectories.py)
  • End-to-end wiki-search (text + tools) — 3 steps, loss decreasing
  • End-to-end tic-tac-toe (VLM + tools) — 10 steps, loss decreasing
  • End-to-end color-codeword (VLM, no tools) — 10 steps, loss decreasing
  • Verified old vs new mask path produces identical token IDs and masks
  • Confirmed non-VLM and regular RL paths are unaffected (processor=None, tokens already populated by vLLM)

🤖 Generated with Claude Code


Note

Medium Risk
Changes orchestrator rollouts, sampling args, and trainer loss selection, including disabling weight broadcast/policy updates in a new external-rollout SFT mode; mistakes could break training/inference behavior or tokenization, especially for multimodal/tool-call trajectories.

Overview
Enables hard SFT distillation from an external OpenAI-compatible teacher endpoint via new orchestrator.teacher_rollout_model, with config validation that enforces trainer.loss.type = "sft", use_token_client = false, and no local [inference].

Updates the orchestrator to route rollout generation through the external model when configured, disabling policy weight updates/weight broadcast and adjusting checkpoint-step handling accordingly.

Adds an sft loss variant (SFTLossConfig + masked-NLL implementation) and introduces a shared chat_template utility plus rollout pretokenization/reconstruction (including VLM image message handling and tool-call argument deserialization), with new unit tests and example/docs/config updates.
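The masked-NLL loss the overview refers to reduces to a simple per-token computation, sketched here in pure Python for clarity (the trainer's implementation is tensor-based; this only illustrates the math):

```python
import math


def masked_nll(logits, targets, mask):
    """Mean negative log-likelihood of the target token over positions
    where mask is True. logits: per-position lists of raw scores."""
    total, kept = 0.0, 0
    for row, tgt, keep in zip(logits, targets, mask):
        if not keep:
            continue
        m = max(row)  # log-sum-exp with max-shift for numerical stability
        logz = m + math.log(sum(math.exp(x - m) for x in row))
        total += logz - row[tgt]
        kept += 1
    return total / max(kept, 1)


# Uniform 2-way logits: NLL of any target at a masked position is ln 2.
loss = masked_nll([[0.0, 0.0], [0.0, 0.0]], [0, 1], [True, False])
assert abs(loss - math.log(2)) < 1e-12
```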

Written by Cursor Bugbot for commit c95d9e1.

willccbb and others added 19 commits March 3, 2026 22:53
…d CustomLossConfig

Co-authored-by: will brown <willccbb@users.noreply.github.com>
# Conflicts:
#	CHANGELOG.md
#	src/prime_rl/orchestrator/orchestrator.py
#	src/prime_rl/orchestrator/scheduler.py
#	src/prime_rl/orchestrator/trajectories.py
#	src/prime_rl/trainer/rl/train.py
#	tests/unit/test_configs.py
full_ids was tokenized via _render_messages then immediately overwritten
by build_incremental_token_mask, wasting a tokenization pass per step.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Duplicate of should_add_generation_prompt in utils/chat_template.py,
never called anywhere.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Explicitly disable top-k and min-p filtering on all vLLM requests,
not just the token client path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…inference

Redundant with the Pydantic model validator definition.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When external teacher rollouts have tool_defs, convert them to OAI
format and pass through to build_incremental_token_mask so that
tokenization includes tool definition tokens from the chat template.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Matches the SFT data pipeline behavior. Without this, parallel tool
call responses cause an assertion failure in incremental tokenization
when the chat template re-renders consecutive tool messages.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Tool definitions arrive as plain dicts after ZMQ msgpack serialization
(verifiers' msgpack_encoder calls model_dump() on Tool pydantic objects).
Use duck-typed accessor to support both dict and object forms.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
For VLM models, the tokenizer alone produces only 1 <|image_pad|> token
per image, but the model expects the count to match actual image patches.
Thread the processor through pretokenize → render_messages →
build_incremental_token_mask so that processor.apply_chat_template
(which expands image placeholders correctly) is used when available.

Also move pretokenize before build_vlm_image_cache in the orchestrator,
since the cache build strips image data from messages.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
In the SFT distillation path, completion is always a single assistant
message (enforced by verifiers' parse_response_message), so the
completion mask is trivially all-True. Replace N incremental
tokenizer/processor calls from build_incremental_token_mask with a
single render_messages call. For VLM models this avoids redundant
image preprocessing on every incremental prefix.

Added assertion to guard the single-assistant-role assumption.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Autofix Details

Bugbot Autofix prepared a fix for the issue found in the latest run.

  • ✅ Fixed: Five trivial wrapper functions add unnecessary indirection
    • Removed the five pass-through wrappers in trajectories and updated internal/test call sites to use the imported chat_template functions directly.

Create PR

Or push these changes by commenting:

@cursor push 50cd9ca098
Preview (50cd9ca098)
diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py
--- a/src/prime_rl/orchestrator/trajectories.py
+++ b/src/prime_rl/orchestrator/trajectories.py
@@ -45,39 +45,6 @@
     zero_entry = [[0] * topk for _ in range(num_layers)]
     return routed_experts + [zero_entry for _ in range(deficit)]
 
-
-def _common_prefix_len(a: list[int], b: list[int]) -> int:
-    return common_prefix_len(a, b)
-
-
-def _normalize_messages(messages: Any, default_role: str) -> list[dict[str, Any]]:
-    return normalize_messages(messages, default_role)
-
-
-def _deserialize_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-    return deserialize_tool_calls(messages)
-
-
-def _strip_message_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-    return strip_message_content(messages)
-
-
-def _render_messages(
-    tokenizer: PreTrainedTokenizer,
-    messages: list[dict[str, Any]],
-    add_generation_prompt: bool = False,
-    tools: list[dict[str, Any]] | None = None,
-    processor=None,
-) -> list[int]:
-    return render_messages(
-        tokenizer,
-        messages,
-        add_generation_prompt=add_generation_prompt,
-        tools=tools,
-        processor=processor,
-    )
-
-
 def _prepare_messages_for_processor(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
     """Convert messages to the format expected by the VLM processor.
 
@@ -124,11 +91,11 @@
     tools: list[dict[str, Any]] | None = None,
     processor=None,
 ) -> dict[str, Any]:
-    prompt = _normalize_messages(step.get("prompt"), default_role="user")
-    completion = _normalize_messages(step.get("completion"), default_role="assistant")
+    prompt = normalize_messages(step.get("prompt"), default_role="user")
+    completion = normalize_messages(step.get("completion"), default_role="assistant")
 
-    prompt = _strip_message_content(_deserialize_tool_calls(prompt))
-    completion = _strip_message_content(_deserialize_tool_calls(completion))
+    prompt = strip_message_content(deserialize_tool_calls(prompt))
+    completion = strip_message_content(deserialize_tool_calls(completion))
 
     assert all(m.get("role") == "assistant" for m in completion), (
         "Expected all completion messages to be assistant role for SFT distillation, "
@@ -141,21 +108,21 @@
 
     all_messages = prompt + completion
     prompt_has_assistant_completion = len(completion) > 0 and completion[0].get("role") == "assistant"
-    prompt_ids = _render_messages(
+    prompt_ids = render_messages(
         tokenizer,
         prompt,
         add_generation_prompt=prompt_has_assistant_completion,
         tools=tools,
         processor=processor,
     )
-    full_ids = _render_messages(
+    full_ids = render_messages(
         tokenizer,
         all_messages,
         tools=tools,
         processor=processor,
     )
 
-    split_idx = _common_prefix_len(prompt_ids, full_ids)
+    split_idx = common_prefix_len(prompt_ids, full_ids)
 
     completion_ids = full_ids[split_idx:]
     completion_mask = [True] * len(completion_ids)

diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py
--- a/tests/unit/orchestrator/test_trajectories.py
+++ b/tests/unit/orchestrator/test_trajectories.py
@@ -10,13 +10,13 @@
 from prime_rl.orchestrator.trajectories import (
     VLMImageCache,
     _align_routed_experts,
-    _deserialize_tool_calls,
     _extract_images_from_examples,
     _extract_images_from_messages,
     _ImageStore,
     build_vlm_image_cache,
     interleave_rollout,
 )
+from prime_rl.utils.chat_template import deserialize_tool_calls
 
 
 def _pixels(data: list[list[float]]) -> tuple[bytes, list[int]]:
@@ -33,7 +33,7 @@
 def test_deserialize_tool_calls_does_not_inject_missing_key():
     messages = [{"role": "assistant", "content": "hello"}]
 
-    deserialized = _deserialize_tool_calls(messages)
+    deserialized = deserialize_tool_calls(messages)
 
     assert "tool_calls" not in deserialized[0]
 
@@ -52,7 +52,7 @@
         }
     ]
 
-    deserialized = _deserialize_tool_calls(messages)
+    deserialized = deserialize_tool_calls(messages)
 
     assert deserialized[0]["tool_calls"][0]["function"]["arguments"] == {"x": 1}



Five trivial wrapper functions add unnecessary indirection

Low Severity

_common_prefix_len, _normalize_messages, _deserialize_tool_calls, _strip_message_content, and _render_messages are single-line wrappers that purely delegate to identically-named functions already imported from prime_rl.utils.chat_template. The callers inside this file (and the test that imports _deserialize_tool_calls) could use the imported functions directly, removing five layers of unnecessary indirection.


Member

@samsja samsja left a comment


lets remove all the custom config please. maybe keep it one

Comment on lines +60 to +65
if use_token_client:
sampling_args["logprobs"] = True
extra_body["return_token_ids"] = True

if extra_body:
sampling_args["extra_body"] = extra_body
Member


we need logprob in both use token client true and false no ?

Contributor Author


added it but we need to make sure passing logprob = True doesn't break an api request to non-vllm api? and same with unconditionally passing in top_k and min_p

Contributor Author


[Screenshot: 2026-03-19 at 5:36 PM]

fyi

Keep examples/alphabet_sort/sft_distill_hard.toml as the canonical
example config.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
eligotts and others added 2 commits March 20, 2026 00:11
logprobs=True was only set when use_token_client=True, but non-token-
client paths (e.g. external APIs) also need logprobs for RL training.
return_token_ids remains gated behind use_token_client since it is
vLLM-specific.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When enable_policy_updates is False (SFT distillation), generate_batch
skipped calling checkpoint_ready.set(). This worked by accident because
the event starts set and is only cleared inside maybe_update_policy
(which is never called in this path), but would deadlock if anything
else ever cleared the event.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Member

@mikasenghaas mikasenghaas left a comment


nice looks p clean overall

Member


are the changes here related to sft distillation on rl trainer?

Contributor Author


i think this came from @willccbb but seems to just be a refactor of helper functions that were duplicated between sft/data.py and orchestrator/trajectories.py into a new shared module utils/chat_template.py - should this be a separate pr?

from prime_rl.utils.vlm import is_vlm_model


def setup_external_rollout_model(config: OrchestratorConfig, logger) -> tuple[Any, str, bool]:
Member


would prefer not starting to have orch utils here

- Revert loss_scale regression: use loss_mask-based scaling unconditionally
- Move setup_external_rollout_model to orchestrator/utils.py
- Clarify log message to say "SFT distillation mode"

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…sequences

When tokenizing trajectory steps from messages, the prompt-only and
full (prompt+completion) tokenizations can diverge at a boundary
(non-prefix split). Previously, the full prompt_ids was returned even
when split_idx < len(prompt_ids), producing corrupted sequences where
prompt_ids + completion_ids != full_ids. This broke interleaving
(extension property checks fail on mismatched prefixes) and fed
invalid token sequences to the trainer.

Fix: use full_ids[:split_idx] as prompt_ids so the concatenation
always equals full_ids exactly.

Verified across GLM-4, Qwen3, Qwen2.5, Qwen2.5-VL tokenizers and
tested end-to-end SFT distillation on alphabet-sort, wiki-search,
tic-tac-toe, and color-codeword (VLM) environments.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
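The fix in this commit reduces to taking the prompt side from the full render, which a small sketch makes concrete (function name illustrative):

```python
def split_on_common_prefix(prompt_ids, full_ids):
    """Return (prompt_ids, completion_ids) such that their concatenation
    is exactly full_ids, even when the prompt-only render diverges from
    the full render at the boundary (a non-prefix split)."""
    split = 0
    while (split < min(len(prompt_ids), len(full_ids))
           and prompt_ids[split] == full_ids[split]):
        split += 1
    return full_ids[:split], full_ids[split:]


# The prompt-only render ends with a trailing token (9) that the full
# render replaces; naive reuse of prompt_ids would corrupt the sequence.
p, c = split_on_common_prefix([1, 2, 9], [1, 2, 3, 4])
assert p + c == [1, 2, 3, 4]
assert p == [1, 2]
```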

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 2 total unresolved issues (including 1 from previous review).


…ation

The previous commit truncating prompt_ids to full_ids[:split_idx] made
the debug log condition always false (split_idx < split_idx). Track the
original prompt length before overwrite so the log fires correctly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>