Binary file added docs/sphinx_doc/assets/opd_acc.png
Binary file added docs/sphinx_doc/assets/opd_kl.png
36 changes: 36 additions & 0 deletions examples/opd_gsm8k/README.md
@@ -0,0 +1,36 @@
# Example: On-Policy Distillation on the GSM8K Dataset

This example demonstrates training with the On-Policy Distillation (OPD) algorithm on the GSM8K dataset.

On-Policy Distillation is a knowledge distillation method. In this example:
1. **Student model** (`Qwen/Qwen2.5-1.5B-Instruct`) generates trajectories along with its logprobs
2. **Teacher model** (`Qwen/Qwen2.5-Math-7B-Instruct`) computes logprobs on the same trajectories
3. The advantage is computed as: `advantages = kl_coef * (teacher_logprobs - student_logprobs)`
4. The student model is trained to minimize this KL divergence, effectively learning from the teacher (a minimal sketch follows below)
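
Concretely, the advantage is just the masked, scaled gap between teacher and student log-probabilities on the sampled response tokens. Here is a minimal standalone sketch in plain PyTorch (the function name and signature are illustrative, not the actual Trinity `AdvantageFn` interface):

```python
import torch

def opd_advantages(
    teacher_logprobs: torch.Tensor,  # [batch, resp_len] teacher logprobs of the sampled tokens
    student_logprobs: torch.Tensor,  # [batch, resp_len] student logprobs at sampling time
    response_mask: torch.Tensor,     # [batch, resp_len] 1 for response tokens, 0 for padding
    kl_coef: float = 1.0,
) -> torch.Tensor:
    # Positive where the teacher assigns higher probability to the sampled token
    # than the student did, so training pushes the student toward the teacher.
    advantages = kl_coef * (teacher_logprobs - student_logprobs)
    return advantages * response_mask
```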

## Key Configuration

- **Algorithm**: `on_policy_distill` (see the sketch after this list for the objective the trainer optimizes)
- **Workflow**: `on_policy_distill_workflow`
- **Student Model**: `Qwen/Qwen2.5-1.5B-Instruct`
- **Teacher Model**: `Qwen/Qwen2.5-Math-7B-Instruct` (configured as auxiliary model)
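
For intuition about what the trainer optimizes with this configuration: the OPD advantage above is plugged into a plain importance-sampling (or PPO) policy loss, and because the config syncs weights every step, the importance ratio stays close to 1, so the objective essentially minimizes the per-token log-probability gap between student and teacher. A standalone sketch under those assumptions (names are illustrative, not the Trinity trainer API):

```python
import torch

def opd_loss(
    logprob: torch.Tensor,          # current student logprobs on the sampled tokens
    old_logprob: torch.Tensor,      # student logprobs at sampling time
    teacher_logprob: torch.Tensor,  # teacher logprobs on the same tokens
    response_mask: torch.Tensor,
    kl_coef: float = 1.0,
) -> torch.Tensor:
    advantages = kl_coef * (teacher_logprob - old_logprob) * response_mask
    ratio = torch.exp(logprob - old_logprob)  # ~1 when fully on-policy
    pg_losses = -ratio * advantages
    # token-mean aggregation over response tokens
    return (pg_losses * response_mask).sum() / response_mask.sum().clamp(min=1.0)
```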

## Running the Example

Download the student and teacher model checkpoints, adjust the config file as needed, then run:
```bash
trinity run examples/opd_gsm8k/opd_gsm8k.yaml
```

Then you are all set! It should be pretty simple 😄, and training should converge very quickly (much faster than RL).

Training curves (accuracy and KL):

![Accuracy](../../docs/sphinx_doc/assets/opd_acc.png)
![KL](../../docs/sphinx_doc/assets/opd_kl.png)

## References

- https://arxiv.org/pdf/2306.13649
- https://thinkingmachines.ai/blog/on-policy-distillation/
74 changes: 74 additions & 0 deletions examples/opd_gsm8k/opd_gsm8k.yaml
@@ -0,0 +1,74 @@
project: "Trinity-RFT-gsm8k-opd"
name: "qwen2.5-1.5B-distill-from-math-7B-lr1e-5"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
  algorithm_type: on_policy_distill
  repeat_times: 8
  optimizer:
    lr: 1e-5
  advantage_fn_args:
    kl_coef: 1.0
model:
  # Student model
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
  max_response_tokens: 1024
  max_model_len: 2048
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 1
  batch_size: 96
  explorer_input:
    taskset:
      name: gsm8k
      storage_type: file
      path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k}
      subset_name: main
      split: train
      format:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        temperature: 1.0
    # Use on_policy_distill_math_workflow for Qwen2.5-Math style format with accuracy reward
    default_workflow_type: 'on_policy_distill_math_workflow'
  trainer_input:
    experience_buffer:
      name: gsm8k_opd_buffer
      storage_type: queue
explorer:
  eval_interval: 50
  runner_per_model: 8
  rollout_model:
    # Student model for rollout
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
  auxiliary_models:
    # Teacher model for distillation
    - model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-Math-7B-Instruct}
      engine_num: 1
      tensor_parallel_size: 2
      enable_prefix_caching: false
      enforce_eager: true
      dtype: bfloat16
      seed: 42
      max_model_len: 4096
      max_prompt_tokens: 2048
      max_response_tokens: 1024
synchronizer:
  sync_method: 'nccl'
  sync_interval: 1
  sync_timeout: 1200
trainer:
  save_interval: 100
  grad_clip: 1.0
  use_dynamic_bsz: true
  max_token_len_per_gpu: 16384
  ulysses_sequence_parallel_size: 1
monitor:
  monitor_type: wandb
4 changes: 4 additions & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -11,6 +11,9 @@
from trinity.algorithm.advantage_fn.multi_step_grpo_advantage import (
    StepWiseGRPOAdvantageFn,
)
from trinity.algorithm.advantage_fn.on_policy_distill_advantage import (
    OnPolicyDistillAdvantage,
)
from trinity.algorithm.advantage_fn.opmd_advantage import (
    OPMDAdvantageFn,
    OPMDGroupAdvantage,
@@ -40,4 +43,5 @@
"REINFORCEGroupAdvantage",
"ASYMREAdvantageFn",
"RECGroupedAdvantage",
"OnPolicyDistillAdvantage",
]
69 changes: 69 additions & 0 deletions trinity/algorithm/advantage_fn/on_policy_distill_advantage.py
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""On-Policy Distillation advantage computation.

Reference: Tinker library's on-policy distillation.

advantages = -(student_logprobs - teacher_logprobs)
           = teacher_logprobs - student_logprobs
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn


@ADVANTAGE_FN.register_module("on_policy_distill")
class OnPolicyDistillAdvantage(AdvantageFn):
    """Advantage function for on-policy distillation.

    Computes: advantages = kl_coef * (teacher_logprobs - student_logprobs)

    The teacher_logprobs should be stored in Experience.teacher_logprobs
    by the workflow during exploration.
    """

    def __init__(self, kl_coef: float = 1.0) -> None:
        self.kl_coef = kl_coef

    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
        """Compute advantages from teacher and student logprobs.

        Args:
            exps: DataProto containing:
                - old_log_probs: student's sampling logprobs [batch, seq]
                - teacher_log_probs: teacher's logprobs [batch, seq]
                - response_mask: mask for response tokens [batch, seq]

        Returns:
            exps: DataProto with advantages and returns added
            metrics: Dict with kl and advantage statistics
        """
        metrics = {}

        old_log_probs = exps.batch["old_log_probs"]  # student sampling logprobs
        teacher_log_probs = exps.batch["teacher_log_probs"]
        response_mask = exps.batch["response_mask"]

        # advantages = -(student - teacher) = teacher - student
        advantages = self.kl_coef * (teacher_log_probs - old_log_probs)

        # Apply mask
        advantages = advantages * response_mask

        exps.batch["advantages"] = advantages
        exps.batch["returns"] = advantages.clone()

        # Metrics
        kl_per_token = old_log_probs - teacher_log_probs
        kl_sum = (kl_per_token * response_mask).sum(dim=-1)
        metrics["kl/mean"] = kl_sum.mean().item()
        metrics["kl/std"] = kl_sum.std().item() if kl_sum.numel() > 1 else 0.0
        metrics["advantages/mean"] = advantages.sum(dim=-1).mean().item()

        return exps, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {"kl_coef": 1.0}
33 changes: 33 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -480,3 +480,36 @@ def default_config(cls) -> Dict:
"kl_loss_fn": "k2",
"entropy_loss_fn": "default",
}


@ALGORITHM_TYPE.register_module("on_policy_distill")
class OnPolicyDistillAlgorithm(AlgorithmType):
    """On-Policy Distillation Algorithm.

    Reference: Tinker library.

    The workflow stores teacher logprobs in Experience.teacher_logprobs.
    The trainer's advantage_fn computes: advantages = teacher_logprobs - student_logprobs
    The trainer uses:
        the importance_sampling loss if no clipping is needed
        the ppo loss if clipping is needed, for better stability
    """

    use_critic: bool = False
    use_reference: bool = False
    compute_advantage_in_trainer: bool = True  # advantage_fn computes from teacher_logprobs
    can_balance_batch: bool = True
    schema: str = "experience"

    @classmethod
    def default_config(cls) -> Dict:
        return {
            "repeat_times": 8,
            "advantage_fn": "on_policy_distill",
            "advantage_fn_args": {"kl_coef": 1.0},
            "sample_strategy": "default",
            "policy_loss_fn": "ppo",  # or importance_sampling if no clipping is needed
            "kl_penalty_fn": "none",
            "kl_loss_fn": "none",
            "entropy_loss_fn": "none",
        }
4 changes: 4 additions & 0 deletions trinity/algorithm/policy_loss_fn/__init__.py
@@ -6,6 +6,9 @@
from trinity.algorithm.policy_loss_fn.cispo_policy_loss import CISPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn
from trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss import (
    ImportanceSamplingLossFn,
)
from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn
from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
@@ -32,5 +35,6 @@
"SFTPhiLossFn",
"sPPOPolicyLossFn",
"RECPolicyLossFn",
"ImportanceSamplingLossFn",
"SAPOPolicyLossFn",
]
61 changes: 61 additions & 0 deletions trinity/algorithm/policy_loss_fn/importance_sampling_policy_loss.py
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""The simplest importance-sampling policy loss.

loss = -(prob_ratio * advantages).sum()
where prob_ratio = exp(current_logprobs - sampling_logprobs)

Note: This loss is used for on-policy distillation.
"""

from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import aggregate_loss, masked_mean


@POLICY_LOSS_FN.register_module("importance_sampling")
class ImportanceSamplingLossFn(PolicyLossFn):
    """Pure importance-sampling loss without clipping.

    loss = -(ratio * advantages)
    where ratio = exp(logprob - old_logprob)
    """

    def __init__(
        self,
        backend: str = "verl",
        loss_agg_mode: str = "token-mean",
    ) -> None:
        super().__init__(backend=backend)
        self.loss_agg_mode = loss_agg_mode

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        # prob_ratio = exp(current_logprobs - sampling_logprobs), clamped for numerical stability
        log_ratio = logprob - old_logprob
        log_ratio = torch.clamp(log_ratio, min=-20.0, max=20.0)
        ratio = torch.exp(log_ratio)

        # loss = -(prob_ratio * advantages)
        pg_losses = -advantages * ratio
        pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)

        metrics = {
            "pg_loss": pg_loss.detach().item(),
            "ratio/mean": masked_mean(ratio, action_mask).detach().item(),
            "approx_kl": masked_mean(-log_ratio, action_mask).detach().item(),
        }

        return pg_loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {"loss_agg_mode": "token-mean"}
2 changes: 1 addition & 1 deletion trinity/common/config.py
@@ -80,7 +80,7 @@ class FormatConfig:
class GenerationConfig:
    temperature: Optional[float] = None  # 1.0
    top_p: Optional[float] = None  # 1.0
    top_k: Optional[int] = None  # -1
    top_k: int = -1  # -1 means disabled
    logprobs: Optional[int] = None  # 0 # vLLM return `logprobs + 1` elements
    max_tokens: Optional[int] = None  # if None, use model.max_response_tokens
    # repeat each task for `n` times
21 changes: 21 additions & 0 deletions trinity/common/experience.py
@@ -133,6 +133,9 @@ class Experience:
    # for multi-modal data
    multi_modal_inputs: Optional[Dict[str, Tensor]] = None  # Multi-modal inputs for verl trainer

    # for on-policy distillation
    teacher_logprobs: Optional[Tensor] = None  # [resp_length]

    def __init__(  # noqa: C901
        self,
        *,
@@ -157,6 +160,7 @@ def __init__( # noqa: C901
        chosen_messages=None,
        rejected_messages=None,
        multi_modal_inputs=None,
        teacher_logprobs=None,
    ):
        if action_mask is not None:
            experience_type = "multi_turn"
@@ -229,6 +233,11 @@ def __init__( # noqa: C901
                else:
                    self.multi_modal_inputs[key] = value

        # Handle teacher_logprobs
        if isinstance(teacher_logprobs, list):
            teacher_logprobs = torch.tensor(teacher_logprobs, dtype=torch.float32)
        self.teacher_logprobs = teacher_logprobs

        if not isinstance(self.tokens, Tensor):
            self.tokens = torch.tensor(self.tokens)
        if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
@@ -239,6 +248,8 @@ def __init__( # noqa: C901
            self.chosen = torch.tensor(self.chosen)
        if self.rejected is not None and not isinstance(self.rejected, Tensor):
            self.rejected = torch.tensor(self.rejected)
        if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor):
            self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32)

    def serialize(self) -> bytes:
        """Serialize the experience to bytes."""
@@ -341,6 +352,14 @@ def gather(
        else:
            multi_modal_inputs = None

        # gather teacher_logprobs
        if all(exp.teacher_logprobs is not None for exp in experiences):
            teacher_logprobs = gather_response_attrs(
                experiences, "teacher_logprobs", max_response_length
            )
        else:
            teacher_logprobs = None

        exps = Experiences(
            eids=eids,
            tokens=tokens,
@@ -353,6 +372,7 @@ def gather(
            prompt_length=max_prompt_length,
            logprobs=logprobs,
            multi_modal_inputs=multi_modal_inputs,
            teacher_logprobs=teacher_logprobs,
        )
        if custom_fields is not None:
            for custom_field in custom_fields:
@@ -442,6 +462,7 @@ class Experiences:
    custom_fields: List[str] = field(
        default_factory=list
    )  # Custom fields to include in the gathered experiences
    teacher_logprobs: Optional[Tensor] = None  # [batch_size, response_length]

    @property
    def batch_size(self) -> int: