diff --git a/apps/on_policy_distillation/README.md b/apps/on_policy_distillation/README.md
new file mode 100644
index 000000000..7d0b26bfc
--- /dev/null
+++ b/apps/on_policy_distillation/README.md
@@ -0,0 +1,77 @@
+# On-Policy Distillation for Math Reasoning
+
+This app implements on-policy distillation (OPD) following the approach described in the [Thinking Machines blog post](https://thinkingmachines.ai/blog/on-policy-distillation/). OPD combines the benefits of on-policy training with dense reward signals for efficient post-training.
+
+## Overview
+
+On-policy distillation trains a student model by:
+1. Sampling trajectories from the student model itself
+2. Using a teacher model to grade each token with dense rewards (per-token KL divergence)
+3. Training the student to minimize reverse KL with the teacher
+
+This approach is **10-30x more compute efficient** than traditional RL while achieving comparable or better performance.
+
+## Experimental Setup
+
+### Models
+- **Student**: Qwen3-1.7B-Base (or Qwen3-8B for larger experiments)
+- **Teacher**: Qwen3-8B (or Qwen3-32B)
+- **Evaluation**: AIME'24 benchmark
+
+### Training Pipeline
+
+#### Phase 1: Supervised Fine-Tuning (SFT)
+First, establish a strong baseline through off-policy distillation:
+
+```bash
+python -m apps.sft.main --config apps/sft/qwen3_1_7b.yaml
+```
+
+- **Dataset**: OpenThoughts3-1.2M (400k prompts)
+- **Expected Performance**: ~40% on AIME'24
+- **Purpose**: Teaches the model basic math reasoning patterns
+
+#### Phase 2: On-Policy Distillation
+Refine the model using on-policy learning with dense supervision:
+
+```bash
+python -m apps.on_policy_distillation.main --config apps/on_policy_distillation/default.yaml
+```
+
+- **Starting Point**: SFT checkpoint from Phase 1
+- **Dataset**: Math prompts (from OpenThoughts3 or DeepMath; prompts only, the solutions are not used)
+- **Training**: ~150-200 steps (77k prompts with 4 samples each)
+- **Expected Performance**: ~50% on AIME'24
+
+### Key Implementation Details
+
+1. **Loss Function**: Per-token reverse KL divergence
+   ```python
+   reverse_kl = student_logprobs - teacher_logprobs
+   ```
+
+2. **Sampling**: Generate multiple trajectories per prompt (n=4 per generator replica in the provided config)
+
+3. **No Discount Factor**: A discount factor of zero, so each token is optimized only against the teacher's immediate next-token distribution
+
+4. **Efficient Batching**: Smaller batch sizes than RL are viable because the reward signal is dense
+
+## Key Advantages
+
+- **Compute Efficiency**: 10-30x reduction vs traditional RL
+- **Dense Supervision**: Learns from every token, not just final rewards
+- **Data Efficiency**: Prompts can be reused multiple times effectively
+- **Stability**: More stable training than sparse RL rewards
+
+## Notes for Reproduction
+
+1. **Ensure proper initialization**: Load the SFT checkpoint before starting OPD
+2. **Use prompts only**: During OPD, sample completions from the student; do not use the dataset solutions
+3. **Teacher quality matters**: Better teachers provide better supervision
+4. **Monitor reverse KL**: It should go to near zero as training progresses
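+
+## Loss Sketch
+
+For reference, a minimal, self-contained sketch of the per-token objective described above. This is an illustration with toy numbers, not the app's implementation; the actual loss is `reverse_kl_loss` in `main.py`.
+
+```python
+import torch
+
+# Log-probabilities that the student and teacher assign to the student's own
+# sampled response tokens, shape [batch, response_len] (toy numbers).
+student_logprobs = torch.tensor([[-2.0, -1.5, -0.5]])
+teacher_logprobs = torch.tensor([[-0.5, -0.4, -0.3]])
+padding_mask = torch.ones_like(student_logprobs)
+
+# Per-token reverse KL of the student with respect to the teacher; its negation
+# is used as a dense per-token advantage (detached from the graph in the real loss).
+reverse_kl = student_logprobs - teacher_logprobs
+advantages = -reverse_kl
+
+# REINFORCE-style surrogate, averaged over non-padded tokens.
+loss = -(advantages * student_logprobs * padding_mask).sum() / padding_mask.sum()
+print(loss)  # tensor(1.5833)
+```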
+
+## References
+
+- [On-Policy Distillation Blog Post](https://thinkingmachines.ai/blog/on-policy-distillation/)
+- [Tinker Cookbook](https://github.com/thinking-machines-lab/tinker-cookbook)
+- [OpenThoughts3 Dataset](https://huggingface.co/datasets/open-thoughts/OpenThoughts3-1.2M)
diff --git a/apps/on_policy_distillation/default.yaml b/apps/on_policy_distillation/default.yaml
new file mode 100644
index 000000000..259f8a547
--- /dev/null
+++ b/apps/on_policy_distillation/default.yaml
@@ -0,0 +1,120 @@
+# On-Policy Distillation: Qwen3 1.7B (student) learning from Qwen3 8B (teacher)
+# >>> python -m apps.on_policy_distillation.main --config apps/on_policy_distillation/default.yaml
+
+# Global configuration
+train_batch_size: 16  # Number of trajectories per training step
+max_req_tokens: 2048
+max_res_tokens: 4096
+student_model: "./Qwen3-1.7B-Base-SFT"  # Path to base model SFT'd on a math dataset
+teacher_model: "Qwen/Qwen3-8B"
+
+# Observability configuration
+metric_logging:
+  wandb:
+    project: opd-training
+    group: opd_exp_${oc.env:USER}
+    logging_mode: global_reduce  # global_reduce, per_rank_reduce, per_rank_no_reduce
+  console:
+    logging_mode: global_reduce
+
+# Dataset configuration
+dataset:
+  path: "zwhe99/DeepMath-103K"
+  split: "train"
+
+# Student generation configuration
+student_generator:
+  engine_args:
+    model: ${student_model}
+    tensor_parallel_size: 1
+    pipeline_parallel_size: 1
+    enforce_eager: false
+  sampling_params:
+    n: 4
+    max_tokens: ${max_res_tokens}
+    temperature: 0.6
+    top_p: 0.95
+
+# Student training configuration
+trainer:
+  model:
+    name: qwen3
+    flavor: 1.7B
+    hf_assets_path: hf://${student_model}
+  optimizer:
+    name: AdamW
+    lr: 5e-5
+    eps: 1e-8
+  lr_scheduler:
+    warmup_steps: 0
+  training:
+    local_batch_size: ${train_batch_size}  # Per-device batch size
+    seq_len: 8192
+    max_norm: 1.0
+    steps: 200
+    dtype: bfloat16
+    gc_freq: 5
+  compile:
+    enable: false
+  parallelism:
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    context_parallel_degree: 1
+    expert_parallel_degree: 1
+    disable_loss_parallel: true
+  checkpoint:
+    enable: true
+    folder: ./checkpoint-opd
+    initial_load_path: ${student_model}
+    initial_load_model_only: true
+    initial_load_in_hf: true
+    last_save_in_hf: true
+    interval: 50
+    async_mode: "disabled"
+  activation_checkpoint:
+    mode: selective
+    selective_ac_option: op
+
+# Teacher model configuration
+teacher:
+  model:
+    name: qwen3
+    flavor: 8B
+    hf_assets_path: hf://${teacher_model}
+  training:
+    seq_len: ${trainer.training.seq_len}
+    dtype: bfloat16
+    gc_freq: 10
+  compile:
+    enable: false
+  parallelism:
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    context_parallel_degree: 1
+    expert_parallel_degree: 1
+  checkpoint:
+    enable: true
+    initial_load_path: hf://${teacher_model}
+    initial_load_in_hf: true
+
+# Resource allocations
+services:
+  student_generator:
+    procs: 1
+    num_replicas: 4
+    mesh_name: student_generator
+    with_gpus: true
+  teacher:
+    procs: 1
+    num_replicas: 2
+    mesh_name: teacher
+    with_gpus: true
+  trainer:
+    procs: 1
+    num_replicas: 1
+    mesh_name: trainer
+    with_gpus: true
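+
+# Note (assuming one GPU per process, as in the other forge apps): the replica
+# counts above request 4 (student generator) + 2 (teacher) + 1 (trainer) = 7 GPUs.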
diff --git a/apps/on_policy_distillation/main.py b/apps/on_policy_distillation/main.py
new file mode 100644
index 000000000..991058edd
--- /dev/null
+++ b/apps/on_policy_distillation/main.py
@@ -0,0 +1,206 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import asyncio
+import itertools
+import time
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+import torchstore as ts
+from datasets import load_dataset
+from forge.actors.generator import Generator
+from forge.actors.reference_model import ReferenceModel
+from forge.actors.trainer import TitanTrainer
+from forge.controller.provisioner import init_provisioner, shutdown
+from forge.data_models.completion import Completion
+from forge.observability.metric_actors import get_or_create_metric_logger
+from forge.util.config import parse
+from forge.util.ops import compute_logprobs
+from omegaconf import DictConfig
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+@dataclass
+class Trajectory:
+    pad_id: int
+    request_len: int
+    response_len: int
+    completion: Completion | None = None
+    teacher_logprobs: torch.Tensor | None = None
+
+    @property
+    def request_tensor(self) -> torch.Tensor:
+        tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long)
+        if tensor.shape[0] < self.request_len:  # left pad
+            diff = self.request_len - tensor.shape[0]
+            tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
+        elif tensor.shape[0] > self.request_len:  # truncate
+            tensor = tensor[-self.request_len :]
+        return tensor
+
+    @property
+    def response_tensor(self) -> torch.Tensor:
+        tensor: torch.Tensor = self.completion.token_ids.to(torch.long)
+        if tensor.shape[0] < self.response_len:  # right pad
+            diff = self.response_len - tensor.shape[0]
+            tensor = F.pad(tensor, (0, diff), value=self.pad_id)
+        elif tensor.shape[0] > self.response_len:  # truncate
+            tensor = tensor[: self.response_len]
+        return tensor
+
+
+def collate(
+    batches: list[list[Trajectory]],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+    inputs = []
+    targets = []
+    for batch in batches:
+        request = [t.request_tensor for t in batch]
+        request = torch.stack(request)
+
+        response = [t.response_tensor for t in batch]
+        response = torch.stack(response)
+
+        teacher_logprobs = [t.teacher_logprobs for t in batch]
+        teacher_logprobs = torch.stack(teacher_logprobs)
+
+        pad_id = batch[0].pad_id
+        padding_mask = response != pad_id
+
+        input = {"tokens": torch.cat([request, response], dim=1)}
+        target = {
+            "response": response,
+            "teacher_logprobs": teacher_logprobs,
+            "padding_mask": padding_mask,
+        }
+        inputs.append(input)
+        targets.append(target)
+    return inputs, targets
+
+
+def reverse_kl_loss(
+    logits: torch.Tensor,
+    response: torch.Tensor,
+    teacher_logprobs: torch.Tensor,
+    padding_mask: torch.Tensor,
+    **kwargs,
+) -> torch.Tensor:
+    # Per-token log-probabilities the student assigns to its own sampled tokens.
+    student_logprobs = compute_logprobs(logits, response)
+
+    # Per-token reverse KL (student relative to teacher); its negation acts as a
+    # dense, detached advantage in a REINFORCE-style surrogate loss.
+    reverse_kl = student_logprobs.detach() - teacher_logprobs
+    advantages = -reverse_kl
+
+    per_token = -(advantages * student_logprobs) * padding_mask
+    loss = per_token.sum() / padding_mask.sum().clamp(min=1)
+
+    return loss
+
+
+async def main(cfg: DictConfig):
+    train_batch_size = cfg.train_batch_size
+    max_steps = cfg.trainer.training.steps
+    max_req_tokens = cfg.max_req_tokens
+    max_res_tokens = cfg.max_res_tokens
+
+    provisioner = await init_provisioner()
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
+    await mlogger.init_backends.call_one(cfg.metric_logging)
+    student_trainer, student_generator, teacher = await asyncio.gather(
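+        # Three components: a TitanTrainer actor that owns the student weights
+        # and trains with reverse_kl_loss, a Generator service that samples
+        # rollouts from the current student, and a ReferenceModel service that
+        # scores those rollouts with the teacher's per-token logprobs.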
+        TitanTrainer.options(**cfg.services.trainer).as_actor(
+            **cfg.trainer, loss=reverse_kl_loss
+        ),
+        Generator.options(**cfg.services.student_generator).as_service(
+            **cfg.student_generator
+        ),
+        ReferenceModel.options(**cfg.services.teacher).as_service(**cfg.teacher),
+    )
+
+    # Initialize torchstore for weight management
+    trainer_num_procs = cfg.services.trainer["procs"]
+    trainer_host_mesh_name = cfg.services.trainer["mesh_name"]
+    trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
+    await ts.initialize(
+        mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
+        strategy=ts.LocalRankStrategy(),
+    )
+
+    print("All services initialized successfully!")
+
+    # Configure the dataset: keep only the prompts and format them with the
+    # student's chat template.
+    tokenizer = get_tokenizer(cfg.student_model)
+
+    def to_prompt(sample):
+        prompt = tokenizer.apply_chat_template(
+            [
+                {
+                    "role": "user",
+                    "content": sample["question"]
+                    + "\n\nPlease reason step by step, and put your final answer within \\boxed{}.",
+                }
+            ],
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        return {"prompt": prompt}
+
+    dataset = load_dataset(cfg.dataset.path, split=cfg.dataset.split).map(to_prompt)
+    dataset_iter = iter(dataset)
+
+    step = 0
+    while step < max_steps:
+        start = time.perf_counter()
+
+        trajectories = []
+        while len(trajectories) < train_batch_size:
+            prompt = next(dataset_iter)["prompt"]
+            completions = await student_generator.generate.fanout(prompt)
+            for completion in itertools.chain(*completions):
+                trajectory = Trajectory(
+                    pad_id=tokenizer.pad_token_id,
+                    request_len=max_req_tokens,
+                    response_len=max_res_tokens,
+                    completion=completion,
+                )
+                input_ids = torch.cat(
+                    [
+                        trajectory.request_tensor.unsqueeze(0),
+                        trajectory.response_tensor.unsqueeze(0),
+                    ],
+                    dim=1,
+                )
+                teacher_logprobs = await teacher.forward.route(
+                    input_ids, max_req_tokens, return_logprobs=True
+                )
+                trajectory.teacher_logprobs = teacher_logprobs
+                trajectories.append(trajectory)
+
+        # Split the collected trajectories into train_batch_size micro-batches.
+        trajectories = [
+            trajectories[i::train_batch_size] for i in range(train_batch_size)
+        ]
+        inputs, targets = collate(trajectories)
+        await student_trainer.train_step.call(inputs, targets)
+
+        await student_trainer.push_weights.call(step)
+        await student_generator.update_weights.fanout(step)
+
+        step += 1
+
+        end = time.perf_counter()
+        print(f"Step {step} took {end - start:.2f} seconds")
+        await mlogger.flush.call_one(step)
+
+    print(f"Training completed after {step} steps")
+    await shutdown()
+
+
+if __name__ == "__main__":
+
+    @parse
+    def _main(cfg):
+        asyncio.run(main(cfg))
+
+    _main()
diff --git a/apps/on_policy_distillation/tests/test_loss.py b/apps/on_policy_distillation/tests/test_loss.py
new file mode 100644
index 000000000..0976272b4
--- /dev/null
+++ b/apps/on_policy_distillation/tests/test_loss.py
@@ -0,0 +1,164 @@
+"""
+Test file comparing reverse_kl_loss from the PR with the Tinker/Thinking Machines implementation
+PR: https://github.com/meta-pytorch/torchforge/pull/527
+
+Citations from the Tinker implementation:
+- Blog post pseudocode: https://thinkingmachines.ai/blog/on-policy-distillation/
+- Tinker Cookbook: https://github.com/thinking-machines-lab/tinker-cookbook
+"""
+
+import torch
+
+from apps.on_policy_distillation.main import reverse_kl_loss
+from forge.util.ops import compute_logprobs
+
+
+class TestReverseKLLoss:
+    """
+    These tests cover a few things:
+    1. Basic input / output / handling of parameters
+    2. Matching the Tinker implementation
+    3. Behaving as expected, i.e. pushing student logprobs in the correct direction
+    """
+
+    def test_vs_tinker_loss(self):
+        """The loss should match the pattern from Tinker's on-policy distillation recipe."""
+        torch.manual_seed(0)
+        batch_size, seq_len, vocab_size = 2, 5, 50

+        prompt = torch.randint(0, vocab_size, (batch_size, seq_len))
+        response = torch.randint(0, vocab_size, (batch_size, seq_len))
+
+        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L71
+        input_ids = torch.cat([prompt, response], dim=-1)
+
+        # Random teacher and student distributions over the full sequence;
+        # compute_logprobs slices out the positions that predict the response tokens.
+        teacher_logits = torch.randn(batch_size, input_ids.size(1) + 1, vocab_size)
+        student_logits = torch.randn(batch_size, input_ids.size(1) + 1, vocab_size)
+
+        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L77
+        teacher_logprobs = compute_logprobs(teacher_logits, response)
+
+        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L86
+        student_logprobs = compute_logprobs(student_logits, response)
+
+        # All response tokens are sampled tokens, so nothing is masked out.
+        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L87
+        mask = torch.ones_like(response, dtype=torch.float)
+
+        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L89
+        reverse_kl = (student_logprobs - teacher_logprobs) * mask
+
+        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L100
+        advantages = -1.0 * mask * reverse_kl
+
+        # REINFORCE-style surrogate: advantage-weighted student logprobs,
+        # averaged over unmasked tokens.
+        expected_loss = -(advantages * student_logprobs).sum() / mask.sum()
+
+        loss = reverse_kl_loss(student_logits, response, teacher_logprobs, mask.bool())
+
+        assert torch.allclose(loss, expected_loss, atol=1e-6)
+
+    def test_zero_kl_property(self):
+        """Test that KL is zero when distributions match perfectly."""
+        batch_size, seq_len, vocab_size = 2, 5, 50
+
+        response = torch.randint(0, vocab_size, (batch_size, seq_len))
+
+        # Create logits for seq_len+1 positions (to predict seq_len response tokens)
+        # compute_logprobs will slice logits[:, -seq_len-1:-1] to align with response
+        logits = torch.full((batch_size, seq_len + 1, vocab_size), -1000.0)
+        for b in range(batch_size):
+            for t in range(seq_len):
+                logits[b, t, response[b, t]] = 0.0
+
+        # Get student log probabilities for selected tokens using compute_logprobs
+        student_logprobs = compute_logprobs(logits, response)
+
+        # Set teacher to match student exactly
+        teacher_logprobs = student_logprobs.clone().detach()
+
+        # No padding
+        padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
+
+        loss = reverse_kl_loss(logits, response, teacher_logprobs, padding_mask)
+
+        # When student matches teacher, reverse_kl = 0, advantages = 0, loss = 0
+        assert abs(loss.item()) < 1e-5, "Loss should be ~0 when student matches teacher"
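+
+    # Added sketch (not part of the original PR): checks that positions masked
+    # out by `padding_mask` do not contribute to the loss. Assumes the same
+    # `compute_logprobs` alignment used in the tests above.
+    def test_padding_positions_are_ignored(self):
+        torch.manual_seed(0)
+        batch_size, seq_len, vocab_size = 2, 6, 20
+
+        response = torch.randint(0, vocab_size, (batch_size, seq_len))
+        logits = torch.randn(batch_size, seq_len + 1, vocab_size)
+        teacher_logprobs = -torch.rand(batch_size, seq_len)
+
+        # Mask out the last two positions of every row.
+        padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
+        padding_mask[:, -2:] = False
+
+        loss = reverse_kl_loss(logits, response, teacher_logprobs, padding_mask)
+
+        # Perturbing teacher logprobs only at masked positions must not change the loss.
+        teacher_logprobs_perturbed = teacher_logprobs.clone()
+        teacher_logprobs_perturbed[:, -2:] += 5.0
+        loss_perturbed = reverse_kl_loss(
+            logits, response, teacher_logprobs_perturbed, padding_mask
+        )
+
+        assert torch.allclose(loss, loss_perturbed)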
+
+    def test_loss_direction(self):
+        """Test that gradients push student logprobs toward teacher."""
+        batch_size, seq_len, vocab_size = 1, 1, 10  # noqa
+
+        # Single token case for clarity
+        response = torch.tensor([[5]])  # Token index 5
+
+        # Student has low probability for token 5
+        # Need seq_len+1 positions for compute_logprobs alignment
+        logits = torch.full((1, 2, vocab_size), 0.0, requires_grad=True)
+        logits.data[0, 0, 5] = -3.0  # Low logit for token 5
+
+        # Teacher has higher probability (less negative logprob)
+        teacher_logprobs = torch.tensor([[-1.0]])
+
+        padding_mask = torch.ones(1, 1, dtype=torch.bool)
+
+        # Compute loss and gradients
+        loss = reverse_kl_loss(logits, response, teacher_logprobs, padding_mask)
+        loss.backward()
+
+        # When student logprob is lower than teacher, gradient should push it higher
+        # Gradient at index 5 should be negative (increase logit -> increase logprob)
+        assert logits.grad is not None
+        assert (
+            logits.grad[0, 0, 5].item() < 0
+        ), "Gradient should push logit higher when student < teacher"
+
+    def test_mode_seeking_behavior(self):
+        """
+        Test that reverse KL exhibits mode-seeking behavior.
+
+        Citation: From the blog post:
+        "reverse KL is 'mode seeking' — it learns one specific behavior
+        (the teacher's) instead of spreading its distribution across
+        several suboptimal options."
+        (https://thinkingmachines.ai/blog/on-policy-distillation/)
+        """
+        batch_size, seq_len, vocab_size = 1, 3, 10
+
+        response = torch.tensor([[2, 5, 7]])
+
+        # Teacher has high confidence (low entropy)
+        teacher_logprobs = torch.tensor([[-0.1, -0.1, -0.1]])
+
+        # Student 1: Spread distribution (high entropy)
+        # Need seq_len+1 positions for compute_logprobs alignment
+        logits_spread = torch.zeros(batch_size, seq_len + 1, vocab_size)
+
+        # Student 2: Focused distribution (low entropy, matching the teacher's confidence)
+        logits_focused = torch.full((batch_size, seq_len + 1, vocab_size), -10.0)
+        logits_focused[0, 0, 2] = 10.0
+        logits_focused[0, 1, 5] = 10.0
+        logits_focused[0, 2, 7] = 10.0
+
+        padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
+
+        # Compute losses
+        loss_spread = reverse_kl_loss(
+            logits_spread, response, teacher_logprobs, padding_mask
+        )
+        loss_focused = reverse_kl_loss(
+            logits_focused, response, teacher_logprobs, padding_mask
+        )
+
+        # Both losses should be finite
+        assert torch.isfinite(loss_spread)
+        assert torch.isfinite(loss_focused)
+
+        # Mode-seeking: the student concentrated on the teacher's preferred tokens
+        # incurs a much lower loss than the spread-out student.
+        assert loss_focused.item() < loss_spread.item()
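+
+    # Added sketch (not part of the original PR): a small end-to-end check that
+    # gradient descent on reverse_kl_loss moves the student toward the teacher.
+    # The optimizer, learning rate, and step count are arbitrary choices for this test.
+    def test_gradient_steps_shrink_reverse_kl(self):
+        torch.manual_seed(0)
+        batch_size, seq_len, vocab_size = 2, 4, 12
+
+        response = torch.randint(0, vocab_size, (batch_size, seq_len))
+        logits = torch.randn(batch_size, seq_len + 1, vocab_size, requires_grad=True)
+
+        # A fixed, fairly confident teacher over the sampled tokens.
+        teacher_logprobs = torch.full((batch_size, seq_len), -0.05)
+        padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
+
+        def mean_reverse_kl():
+            with torch.no_grad():
+                student_logprobs = compute_logprobs(logits, response)
+                return (student_logprobs - teacher_logprobs).mean().item()
+
+        kl_before = mean_reverse_kl()
+
+        optimizer = torch.optim.SGD([logits], lr=1.0)
+        for _ in range(50):
+            optimizer.zero_grad()
+            loss = reverse_kl_loss(logits, response, teacher_logprobs, padding_mask)
+            loss.backward()
+            optimizer.step()
+
+        kl_after = mean_reverse_kl()
+
+        # The magnitude of the per-token reverse KL estimate should shrink as the
+        # student's logprobs move toward the teacher's.
+        assert abs(kl_after) < abs(kl_before)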