26 commits
- `bc2670f` SFT distillation with teacher endpoint (willccbb, Feb 26, 2026)
- `d57726c` streamline SFT distillation rollout/eval model split (willccbb, Mar 4, 2026)
- `5ecd61b` fix sft reconstruction and default loss scaling regressions (willccbb, Mar 4, 2026)
- `0002457` simplify SFT distillation by removing eval inference split (willccbb, Mar 4, 2026)
- `25a50f5` chore: apply ruff formatting (willccbb, Mar 5, 2026)
- `06f352d` docs: add changelog entries for RolloutModelConfig, SFTLossConfig, an… (cursoragent, Mar 5, 2026)
- `54d306a` refine external SFT rollout path and cleanup tests (willccbb, Mar 6, 2026)
- `5644df3` refactor shared SFT chat-template tokenization logic (willccbb, Mar 7, 2026)
- `bace99c` Merge remote-tracking branch 'origin/main' into eli/sft-distillation (eligotts, Mar 19, 2026)
- `20fc8b0` fix: remove dead store in _tokenize_step_from_messages (eligotts, Mar 19, 2026)
- `14faa48` fix: remove unused _should_add_generation_prompt from trajectories (eligotts, Mar 19, 2026)
- `586ae0c` fix: pass top_k and min_p for both token and non-token client paths (eligotts, Mar 19, 2026)
- `08eb784` chore: remove redundant test_rl_config_external_rollout_mode_rejects_… (eligotts, Mar 19, 2026)
- `a889bb0` fix: correct CHANGELOG field name to teacher_rollout_model (eligotts, Mar 19, 2026)
- `5b7a31f` fix: pass tool definitions through pretokenize pipeline (eligotts, Mar 19, 2026)
- `a0ac2ab` fix: enable collapse_consecutive_tool_messages in pretokenization (eligotts, Mar 19, 2026)
- `5a64494` fix: handle dict tool_defs in _convert_tools_to_oai_format (eligotts, Mar 19, 2026)
- `05758a6` feat: VLM processor support in SFT distillation pretokenization (eligotts, Mar 19, 2026)
- `61b1d71` perf: replace incremental token mask with direct completion masking (eligotts, Mar 19, 2026)
- `4cac8fb` chore: remove duplicate SFT distillation configs (eligotts, Mar 20, 2026)
- `6e71436` fix: always request logprobs in sampling args (eligotts, Mar 20, 2026)
- `5472547` chore: remove test_rl_config_external_rollout_mode_rejects_token_client (eligotts, Mar 20, 2026)
- `25e7fca` fix: set checkpoint_ready in non-policy-update path (eligotts, Mar 20, 2026)
- `1001e11` Address PR review feedback (eligotts, Mar 20, 2026)
- `a0306cb` fix: truncate prompt_ids to common prefix to prevent corrupted token … (eligotts, Mar 20, 2026)
- `c95d9e1` fix: restore non-prefix tokenization debug log after prompt_ids trunc… (eligotts, Mar 20, 2026)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -106,6 +106,9 @@ Documenting changes which affect configuration usage patterns (added/moved/remov
- **`[deployment]` (inference)**: Added deployment configuration. `type = "single_node"` (default) with `gpus_per_node`. `type = "multi_node"` with `num_nodes` and `gpus_per_node` — requires `[slurm]` (2026-02-26)
- **`inference.output_dir`**: Added directory for SLURM logs and generated scripts (default: `"outputs"`) (2026-02-26)
- **`inference.dry_run`**: Added flag (default: `False`). When set, validates config, writes resolved config to `output_dir/configs/`, and exits without starting inference or submitting SLURM jobs (2026-02-26)
- **`orchestrator.teacher_rollout_model`**: Added optional external rollout model configuration. When set, rollouts are generated from this endpoint/model instead of the student inference server. Accepts a `TeacherRolloutModelConfig` with `client` and `model` sub-fields, or `None` (default: `None`) (2026-02-26)
- **`trainer.loss.type = "sft"`**: Added SFT loss variant. Set `trainer.loss.type = "sft"` to use masked negative log-likelihood loss instead of the default RL loss (2026-02-26)
- **`trainer.loss.type = "custom"`**: Added custom loss variant. Set `trainer.loss.type = "custom"` with `import_path` (e.g. `"my_module.my_loss"`) and optional `kwargs` to use an external loss function (2026-02-26)
- **`trainer.loss` (default loss)**: Made IPO (DPPO-Binary TV variant ([arxiv](https://arxiv.org/pdf/2602.04879)) + Kimi-K2.5 KL ([Kimi-K2.5](https://arxiv.org/pdf/2602.02276))) the default loss. Removed `ratio_type`, `token_mask_low`, `token_mask_high`, `sequence_clip_high`, `geo_mask_low`, `geo_mask_high`, `sequence_mask_low`, `sequence_mask_high`. Added `ipo_mask_low` (default: 0.2) and `ipo_mask_high` (default: 0.2) for token-level probability-difference masking. Changed `kl_tau` default from `0.0` to `1e-3`. (2026-03-02)
- **Metrics logging overhaul**: All orchestrator metrics now follow a `{metric}/{scope}/{stat}` naming convention where scope is `all` (global) or an env name. Per-env breakdowns are always logged (previously only when >1 env). Key renames: `reward/mean` → `reward/all/mean`, `batch/solve_none` → `solve_none/all`, `val_reward/` → `val/reward/`, `metrics/{name}` → `metrics/{env}/{name}`, `stop_condition/{sc}` → `stop_condition/all/{sc}`, `error/mean` → `error/all/mean`. New per-env metrics: `solve_none/{env}`, `solve_all/{env}`, `effective_batch_size/{env}`, `stop_condition/{env}/generation_truncated`, `stop_condition/{env}/{sc}`, `error/{env}/mean`, plus per-env breakdowns for `seq_len`, `prefill_len`, `decode_len`, `is_truncated`, `samples_per_rollout`, `num_turns`, `generation_ms`, `scoring_ms`. Removed: `reward/std`, `reward/median`, per-error-type breakdown (`error/{error_type}`). Env name `"all"` is reserved and rejected at config validation. Env name validation now strips `@version` suffixes to match runtime behavior (e.g. `math@1.0` and `math@2.0` correctly detected as duplicates). Solve stats now use `example_id` grouping instead of index arithmetic (pre-existing bug fix). (2026-03-08)
- **`wandb.shared`** (experimental): Added shared W&B mode that logs trainer and orchestrator metrics to a single W&B run instead of two separate runs. Uses `wandb.Settings(mode="shared")` (requires wandb SDK >= 0.19.9). Enabled by default on the RL entrypoint. Disable with `--wandb.shared False`. Run ID is communicated via `WANDB_SHARED_RUN_ID` env var; process role via `WANDB_SHARED_LABEL`. Non-primary processes retry `wandb.init` up to 30 times waiting for the primary to create the run. Works with multi-node SLURM. (2026-03-18)
29 changes: 29 additions & 0 deletions configs/ci/integration/alphabet_sort_branch/start.toml
@@ -0,0 +1,29 @@
max_steps = 10
seq_len = 2048

[deployment]
type = "single_node"
num_infer_gpus = 1
num_train_gpus = 1

[ckpt]

[model]
name = "Qwen/Qwen3-0.6B"

[trainer.optim]
lr = 1e-5

[orchestrator]
batch_size = 128
rollouts_per_example = 8

[orchestrator.sampling]
max_tokens = 1024

[[orchestrator.env]]
id = "primeintellect/alphabet-sort"
name = "alphabet-sort"
args = { min_turns = 2, max_turns = 2, min_names_per_turn = 1, max_names_per_turn = 3, similarity_power = 4, power_per_turn = false }

[inference]
41 changes: 41 additions & 0 deletions docs/on_policy_distillation.md
@@ -62,6 +62,47 @@ enabled = false # Skip expensive verification

This runs pure on-policy distillation: the student learns to match the teacher without needing any reward signal.

## SFT Distillation ("Hard Distillation") From Teacher Rollouts

Use this mode when you want to train from teacher-generated completions directly (hard distillation), without teacher token-level logprobs.

```toml
[trainer.loss]
type = "sft"

[orchestrator]
use_token_client = false

[orchestrator.teacher_rollout_model.client]
base_url = ["https://your-openai-compatible-endpoint/v1"]
skip_model_check = true

[orchestrator.teacher_rollout_model.model]
name = "teacher-model-name"
```

In this mode:
- Rollouts are generated from `orchestrator.teacher_rollout_model`
- The orchestrator uses text-level reconstruction with the student tokenizer
- The RL trainer optimizes masked NLL (`trainer.loss.type = "sft"`)
- Omit `[inference]` (no local inference server required)
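The masked NLL the trainer optimizes in this mode can be illustrated in plain Python (a sketch, not the trainer's implementation): completion tokens get mask 1, prompt tokens mask 0, and the loss averages negative log-probabilities over the masked positions.

```python
def masked_nll(logprobs: list[float], mask: list[int]) -> float:
    """Mean negative log-likelihood over positions where mask == 1.

    logprobs[i] is the student's log-probability of the reference token
    at position i; mask marks completion tokens (1) vs. prompt tokens (0).
    """
    total = sum(-lp for lp, m in zip(logprobs, mask) if m)
    count = sum(mask)
    return total / max(count, 1)

# Prompt tokens (mask 0) are ignored; only completion tokens contribute.
loss = masked_nll([-0.1, -0.2, -1.0, -0.5], [0, 0, 1, 1])
print(round(loss, 3))  # 0.75
```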

### Image Input (VLM) Support

Image input is supported in SFT/hard-distillation mode when the student model is multimodal (VLM).

- Prompts can include OpenAI-style image items in `message.content`, e.g. `{"type": "image_url", "image_url": {"url": "data:image/..."}}`
- The orchestrator extracts and preprocesses images from trajectory prompts and attaches `pixel_values`/`image_grid_thw` to training samples
- No teacher token IDs/logprobs are required; reconstruction still happens from messages

Notes:
- This path currently expects `data:image/...` payloads in message content
- The teacher rollout endpoint must also accept the same multimodal prompts during generation
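Decoding `data:image/...` payloads from OpenAI-style content items can be sketched as follows (a hypothetical helper, not the orchestrator's actual code):

```python
import base64

def extract_image_bytes(messages: list[dict]) -> list[bytes]:
    """Decode base64 image payloads from OpenAI-style image_url content items."""
    images = []
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue  # plain-string content carries no image items
        for item in content:
            if item.get("type") != "image_url":
                continue
            url = item["image_url"]["url"]
            if url.startswith("data:image/"):
                # data:image/png;base64,<payload>
                _, payload = url.split(",", 1)
                images.append(base64.b64decode(payload))
    return images

msgs = [{"role": "user", "content": [
    {"type": "text", "text": "Sort these names."},
    {"type": "image_url", "image_url": {"url": "data:image/png;base64," + base64.b64encode(b"fake").decode()}},
]}]
print(extract_image_bytes(msgs))  # [b'fake']
```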

Reference configs:
- `configs/alphabet_sort/sft_distill_hard_qwen4b_lora_prime_teacher.toml`
- `examples/alphabet_sort/sft_distill_hard.toml`

## Parameters

| Parameter | Default | Description |
51 changes: 51 additions & 0 deletions examples/alphabet_sort/sft_distill_hard.toml
@@ -0,0 +1,51 @@
max_steps = 24
seq_len = 2048

[deployment]
type = "single_node"
gpus_per_node = 2
num_train_gpus = 2
num_infer_gpus = 0

[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

[wandb]
project = "alphabet-sort-sft-distill"
name = "qwen3-4b-hard-distill"

[ckpt]

[trainer.optim]
lr = 1e-5

[trainer.loss]
type = "sft"

[trainer.model.lora]
rank = 16
alpha = 32

[trainer.ckpt.weights]
save_adapter_separately = true

[orchestrator]
batch_size = 256
rollouts_per_example = 4
use_token_client = false

[orchestrator.sampling]
max_tokens = 512
temperature = 0.7

[orchestrator.teacher_rollout_model.client]
base_url = ["https://api.pinference.ai/api/v1"]
api_key_var = "PRIME_API_KEY"

[orchestrator.teacher_rollout_model.model]
name = "qwen/qwen3-235b-a22b-instruct-2507"

[[orchestrator.env]]
id = "primeintellect/alphabet-sort"
name = "alphabet-sort"
args = { min_turns = 2, max_turns = 2, min_names_per_turn = 1, max_names_per_turn = 3, similarity_power = 4, power_per_turn = false }
26 changes: 26 additions & 0 deletions skills/toml-config/SKILL.md
@@ -116,6 +116,32 @@ SFT deployment follows the same pattern as RL:
- `output_dir` must be explicitly set when using SLURM
- Multi-node deployment requires `[slurm]` to be set

## SFT Distillation (Hard Distillation) With Teacher Rollouts

Use this when the teacher is an external OpenAI-compatible endpoint and you want to train from teacher completions directly (no teacher token logprobs required).

```toml
[trainer.loss]
type = "sft"

[orchestrator]
use_token_client = false

[orchestrator.teacher_rollout_model.client]
base_url = ["https://your-openai-compatible-endpoint/v1"]
skip_model_check = true

[orchestrator.teacher_rollout_model.model]
name = "teacher-model-name"
```

Notes:
- `orchestrator.teacher_rollout_model` switches rollout generation to the external teacher endpoint.
- `use_token_client = false` is required when `orchestrator.teacher_rollout_model` is set.
- `trainer.loss.type = "sft"` makes the RL trainer optimize masked NLL like SFT.
- In this mode, omit `[inference]`.
- Image input is supported when using a VLM student model and OpenAI-style image messages (`data:image/...`).

## Available commands

All accept `@ config.toml` and CLI overrides:
25 changes: 25 additions & 0 deletions src/prime_rl/configs/orchestrator.py
@@ -666,6 +666,20 @@ class TeacherModelConfig(BaseConfig):
     ] = ModelConfig()
 
 
+class TeacherRolloutModelConfig(BaseConfig):
+    """Configures an external teacher model used to generate rollout text."""
+
+    client: Annotated[
+        ClientConfig,
+        Field(description="The OAI client configuration for rollout generation."),
+    ] = ClientConfig()
+
+    model: Annotated[
+        ModelConfig,
+        Field(description="The model configuration for rollout generation."),
+    ] = ModelConfig()
+
+
 class OrchestratorConfig(BaseConfig):
     """Configures the orchestrator for RL training."""
 
@@ -688,6 +702,17 @@
         ),
     ] = None
 
+    # External teacher rollout model configuration (optional)
+    teacher_rollout_model: Annotated[
+        TeacherRolloutModelConfig | None,
+        Field(
+            description=(
+                "Optional external teacher model used for rollout generation. "
+                "When set, rollouts are generated from this endpoint/model instead of the student inference server."
+            ),
+        ),
+    ] = None
+
     # The sampling configuration
     sampling: SamplingConfig = SamplingConfig()
21 changes: 21 additions & 0 deletions src/prime_rl/configs/rl.py
@@ -363,6 +363,27 @@ def validate_teacher_model(self):
             )
         return self
 
+    @model_validator(mode="after")
+    def validate_external_rollout_mode(self):
+        if self.orchestrator.teacher_rollout_model is None:
+            return self
+
+        if self.trainer.loss.type != "sft":
+            raise ValueError('orchestrator.teacher_rollout_model is only supported when trainer.loss.type = "sft".')
+
+        if self.inference is not None:
+            raise ValueError(
+                "inference must be omitted when orchestrator.teacher_rollout_model is configured. "
+                "External rollout mode does not use the local inference server."
+            )
+
+        if self.orchestrator.use_token_client:
+            raise ValueError(
+                "orchestrator.use_token_client must be false when orchestrator.teacher_rollout_model is configured."
+            )
+
+        return self
+
     ### Auto-setup and validate shared configs
 
     @model_validator(mode="after")
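The constraints enforced by `validate_external_rollout_mode` can be mirrored in a standalone sketch (plain Python; the real check lives in a pydantic `model_validator` on the RL config):

```python
def check_external_rollout_mode(
    loss_type: str,
    has_inference: bool,
    use_token_client: bool,
    has_teacher_rollout: bool,
) -> None:
    """Mirror of the config constraints: teacher_rollout_model requires
    sft loss, no [inference] section, and use_token_client = false."""
    if not has_teacher_rollout:
        return  # constraints only apply in external rollout mode
    if loss_type != "sft":
        raise ValueError('teacher_rollout_model requires trainer.loss.type = "sft"')
    if has_inference:
        raise ValueError("omit [inference] when teacher_rollout_model is set")
    if use_token_client:
        raise ValueError("use_token_client must be false with teacher_rollout_model")

# A valid external-rollout config passes silently; violations raise ValueError.
check_external_rollout_mode("sft", has_inference=False, use_token_client=False, has_teacher_rollout=True)
```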
8 changes: 7 additions & 1 deletion src/prime_rl/configs/trainer.py
@@ -539,6 +539,12 @@ class DefaultLossConfig(BaseModel):
     kl_tau: Annotated[float, Field(ge=0, description="The tau for KL divergence.")] = 1e-3
 
 
+class SFTLossConfig(BaseModel):
+    """Config for SFT-style masked negative log-likelihood loss."""
+
+    type: Literal["sft"] = "sft"
+
+
 class CustomLossConfig(BaseModel):
     """Config for a custom external loss function."""
 
@@ -548,7 +554,7 @@ class CustomLossConfig(BaseModel):
     kwargs: Annotated[dict[str, Any], Field(default_factory=dict, description="Kwargs to pass to the loss function")]
 
 
-LossConfig: TypeAlias = Annotated[DefaultLossConfig | CustomLossConfig, Field(discriminator="type")]
+LossConfig: TypeAlias = Annotated[DefaultLossConfig | SFTLossConfig | CustomLossConfig, Field(discriminator="type")]
 
 
 class FakeDataLoaderConfig(BaseConfig):
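The `Field(discriminator="type")` union resolves a raw config dict to a variant by its `type` tag; the mechanics can be sketched without pydantic (variant names taken from the diff; the `"default"` and `"custom"` tag values are assumptions for illustration):

```python
LOSS_VARIANTS = {
    "default": "DefaultLossConfig",
    "sft": "SFTLossConfig",
    "custom": "CustomLossConfig",
}

def resolve_loss_variant(raw: dict) -> str:
    """Pick the loss-config variant from the "type" discriminator,
    analogous to how pydantic resolves the LossConfig union."""
    loss_type = raw.get("type", "default")
    try:
        return LOSS_VARIANTS[loss_type]
    except KeyError:
        raise ValueError(f"unknown loss type: {loss_type!r}") from None

print(resolve_loss_variant({"type": "sft"}))  # SFTLossConfig
```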
11 changes: 8 additions & 3 deletions src/prime_rl/entrypoints/rl.py
@@ -197,9 +197,14 @@ def rl_local(config: RLConfig):
             monitor_thread.start()
             monitor_threads.append(monitor_thread)
         else:
-            logger.warning(
-                "No inference config specified, skipping starting inference server. Make sure your inference server is running."
-            )
+            if config.orchestrator.teacher_rollout_model is None:
+                logger.warning(
+                    "No inference config specified, skipping starting inference server. Make sure your inference server is running."
+                )
+            else:
+                logger.info(
+                    "No inference config specified, using orchestrator.teacher_rollout_model for rollout generation."
+                )
 
     # Optionally, start teacher inference process
     if config.teacher_inference: