Binary file added docs/sphinx_doc/assets/opd_acc.png
Binary file added docs/sphinx_doc/assets/opd_kl.png
36 changes: 36 additions & 0 deletions examples/opd_gsm8k/README.md
@@ -0,0 +1,36 @@
# Example: On-Policy Distillation on the GSM8K Dataset

This example demonstrates training with the On-Policy Distillation (OPD) algorithm on the GSM8K dataset.

On-Policy Distillation is a knowledge distillation method. In this example:
1. **Student model** (`Qwen/Qwen2.5-1.5B-Instruct`) generates trajectories with logprobs
2. **Teacher model** (`Qwen/Qwen2.5-Math-7B-Instruct`) computes logprobs on the same trajectories
3. The advantage is computed as: `advantages = kl_coef * (teacher_logprobs - student_logprobs)`
4. The student model is trained to minimize this KL divergence, effectively learning from the teacher (see the toy sketch below)
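
A minimal toy sketch of step 3, using made-up numbers and plain PyTorch (not the library's API), just to show what the per-token advantages look like:

```python
import torch

kl_coef = 1.0
# Per-token logprobs on one sampled response (illustrative values only).
student_logprobs = torch.tensor([-0.9, -1.2, -0.3])  # student's own sampling logprobs
teacher_logprobs = torch.tensor([-0.5, -0.8, -0.3])  # teacher scored on the same tokens

advantages = kl_coef * (teacher_logprobs - student_logprobs)
print(advantages)  # tensor([0.4000, 0.4000, 0.0000]) -> positive where the teacher is more confident
```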

## Key Configuration

- **Algorithm**: `on_policy_distill`
- **Workflow**: `on_policy_distill_workflow` (this example's config uses the `on_policy_distill_math_workflow` variant for Qwen2.5-Math style prompts with an accuracy reward)
- **Student Model**: `Qwen/Qwen2.5-1.5B-Instruct`
- **Teacher Model**: `Qwen/Qwen2.5-Math-7B-Instruct` (configured as an auxiliary model)

## Running the Example

Download the student and teacher model checkpoints and modify your config file accordingly, then run:
```bash
trinity run examples/opd_gsm8k/opd_gsm8k.yaml
```

Then you are all set! It should be pretty simple 😄, and training should converge very quickly (much faster than RL).

![Accuracy on GSM8K](../../docs/sphinx_doc/assets/opd_acc.png)
![KL divergence to the teacher](../../docs/sphinx_doc/assets/opd_kl.png)

## References

- https://arxiv.org/pdf/2306.13649
- https://thinkingmachines.ai/blog/on-policy-distillation/
74 changes: 74 additions & 0 deletions examples/opd_gsm8k/opd_gsm8k.yaml
@@ -0,0 +1,74 @@
project: "Trinity-RFT-gsm8k-opd"
name: "qwen2.5-1.5B-distill-from-math-7B-lr1e-5"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: on_policy_distill
repeat_times: 8
optimizer:
lr: 1e-5
advantage_fn_args:
kl_coef: 1.0
model:
# Student model
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
max_response_tokens: 1024
max_model_len: 2048
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k}
subset_name: main
split: train
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
# Use on_policy_distill_math_workflow for Qwen2.5-Math style format with accuracy reward
default_workflow_type: 'on_policy_distill_math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_opd_buffer
storage_type: queue
explorer:
eval_interval: 50
runner_per_model: 8
rollout_model:
# Student model for rollout
engine_num: 4
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
auxiliary_models:
# Teacher model for distillation
- model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-Math-7B-Instruct}
engine_num: 1
tensor_parallel_size: 2
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
max_model_len: 4096
max_prompt_tokens: 2048
max_response_tokens: 1024
synchronizer:
sync_method: 'nccl'
sync_interval: 1
sync_timeout: 1200
trainer:
save_interval: 100
grad_clip: 1.0
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
monitor:
monitor_type: wandb
4 changes: 4 additions & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -11,6 +11,9 @@
from trinity.algorithm.advantage_fn.multi_step_grpo_advantage import (
StepWiseGRPOAdvantageFn,
)
from trinity.algorithm.advantage_fn.on_policy_distill_advantage import (
OnPolicyDistillAdvantage,
)
from trinity.algorithm.advantage_fn.opmd_advantage import (
OPMDAdvantageFn,
OPMDGroupAdvantage,
@@ -40,4 +43,5 @@
"REINFORCEGroupAdvantage",
"ASYMREAdvantageFn",
"RECGroupedAdvantage",
"OnPolicyDistillAdvantage",
]
69 changes: 69 additions & 0 deletions trinity/algorithm/advantage_fn/on_policy_distill_advantage.py
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""On-Policy Distillation advantage computation.
Reference: Tinker library's on-policy distillation.
advantages = -(student_logprobs - teacher_logprobs)
= teacher_logprobs - student_logprobs
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn


@ADVANTAGE_FN.register_module("on_policy_distill")
class OnPolicyDistillAdvantage(AdvantageFn):
"""Advantage function for on-policy distillation.
Computes: advantages = kl_coef * (teacher_logprobs - student_logprobs)
The teacher_logprobs should be stored in Experience.teacher_logprobs
by the workflow during exploration.
"""

def __init__(self, kl_coef: float = 1.0) -> None:
self.kl_coef = kl_coef

def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
"""Compute advantages from teacher and student logprobs.
Args:
exps: DataProto containing:
- old_log_probs: student's sampling logprobs [batch, seq]
- teacher_log_probs: teacher's logprobs [batch, seq]
- response_mask: mask for response tokens [batch, seq]
Returns:
exps: DataProto with advantages and returns added
metrics: Dict with kl and advantage statistics
"""
metrics = {}

old_log_probs = exps.batch["old_log_probs"] # student sampling logprobs
teacher_log_probs = exps.batch["teacher_log_probs"]
response_mask = exps.batch["response_mask"]

# advantages = -(student - teacher) = teacher - student
advantages = self.kl_coef * (teacher_log_probs - old_log_probs)

# Apply mask
advantages = advantages * response_mask

exps.batch["advantages"] = advantages
exps.batch["returns"] = advantages.clone()

# Metrics
kl_per_token = old_log_probs - teacher_log_probs
kl_sum = (kl_per_token * response_mask).sum(dim=-1)
metrics["kl/mean"] = kl_sum.mean().item()
metrics["kl/std"] = kl_sum.std().item() if kl_sum.numel() > 1 else 0.0
metrics["advantages/mean"] = advantages.sum(dim=-1).mean().item()

return exps, metrics

@classmethod
def default_args(cls) -> Dict:
return {"kl_coef": 1.0}
33 changes: 33 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -480,3 +480,36 @@ def default_config(cls) -> Dict:
"kl_loss_fn": "k2",
"entropy_loss_fn": "default",
}


@ALGORITHM_TYPE.register_module("on_policy_distill")
class OnPolicyDistillAlgorithm(AlgorithmType):
"""On-Policy Distillation Algorithm.

Reference: Tinker library.

Workflow stores teacher logprobs in Experience.teacher_logprobs.
Trainer's advantage_fn computes: advantages = teacher_logprobs - student_logprobs
Trainer uses:
importance_sampling loss if no clipping is needed
ppo loss if clipping is needed, for better stability
"""

use_critic: bool = False
use_reference: bool = False
compute_advantage_in_trainer: bool = True # advantage_fn computes from teacher_logprobs
can_balance_batch: bool = True
schema: str = "experience"

@classmethod
def default_config(cls) -> Dict:
return {
"repeat_times": 8,
"advantage_fn": "on_policy_distill",
"advantage_fn_args": {"kl_coef": 1.0},
"sample_strategy": "default",
"policy_loss_fn": "ppo", # or importance_sampling if no clipping is needed
"kl_penalty_fn": "none",
"kl_loss_fn": "none",
"entropy_loss_fn": "none",
}
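
The docstring above leaves the choice between the `ppo` and `importance_sampling` losses to the user. A hedged sketch of how that override could look, starting from the defaults returned by `default_config()` (exact config plumbing may differ in your setup):

```python
from trinity.algorithm.algorithm import OnPolicyDistillAlgorithm

algo_config = OnPolicyDistillAlgorithm.default_config()
# Drop PPO clipping and use the pure importance-sampling loss instead.
algo_config["policy_loss_fn"] = "importance_sampling"
print(algo_config)
```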
4 changes: 4 additions & 0 deletions trinity/algorithm/policy_loss_fn/__init__.py
@@ -6,6 +6,9 @@
from trinity.algorithm.policy_loss_fn.cispo_policy_loss import CISPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn
from trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss import (
ImportanceSamplingLossFn,
)
from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn
from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
@@ -32,5 +35,6 @@
"SFTPhiLossFn",
"sPPOPolicyLossFn",
"RECPolicyLossFn",
"ImportanceSamplingLossFn",
"SAPOPolicyLossFn",
]
61 changes: 61 additions & 0 deletions trinity/algorithm/policy_loss_fn/importance_sampling_policy_loss.py
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""The most simple Importance Sampling policy loss.

loss = -(prob_ratio * advantages), aggregated according to loss_agg_mode
where prob_ratio = exp(current_logprobs - sampling_logprobs)

Note: This loss is used for on-policy distillation.
"""

from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import aggregate_loss, masked_mean


@POLICY_LOSS_FN.register_module("importance_sampling")
class ImportanceSamplingLossFn(PolicyLossFn):
"""Pure importance sampling loss without clipping.

loss = -(ratio * advantages)
where ratio = exp(logprob - old_logprob)
"""

def __init__(
self,
backend: str = "verl",
loss_agg_mode: str = "token-mean",
) -> None:
super().__init__(backend=backend)
self.loss_agg_mode = loss_agg_mode

def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
# prob_ratio = exp(current_logprobs - sampling_logprobs)
log_ratio = logprob - old_logprob
log_ratio = torch.clamp(log_ratio, min=-20.0, max=20.0)
ratio = torch.exp(log_ratio)

# loss = -(prob_ratio * advantages)
pg_losses = -advantages * ratio
pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)

metrics = {
"pg_loss": pg_loss.detach().item(),
"ratio/mean": masked_mean(ratio, action_mask).detach().item(),
"approx_kl": masked_mean(-log_ratio, action_mask).detach().item(),
}

return pg_loss, metrics

@classmethod
def default_args(cls) -> Dict:
return {"loss_agg_mode": "token-mean"}
2 changes: 1 addition & 1 deletion trinity/common/config.py
@@ -80,7 +80,7 @@ class FormatConfig:
class GenerationConfig:
temperature: Optional[float] = None # 1.0
top_p: Optional[float] = None # 1.0
top_k: Optional[int] = None # -1
top_k: int = -1 # -1 means disabled
logprobs: Optional[int] = None # 0 # vLLM return `logprobs + 1` elements
max_tokens: Optional[int] = None # if None, use model.max_response_tokens
# repeat each task for `n` times
21 changes: 21 additions & 0 deletions trinity/common/experience.py
@@ -133,6 +133,9 @@ class Experience:
# for multi-modal data
multi_modal_inputs: Optional[Dict[str, Tensor]] = None # Multi-modal inputs for verl trainer

# for on-policy distillation
teacher_logprobs: Optional[Tensor] = None # [resp_length]

def __init__( # noqa: C901
self,
*,
@@ -157,6 +160,7 @@ def __init__( # noqa: C901
chosen_messages=None,
rejected_messages=None,
multi_modal_inputs=None,
teacher_logprobs=None,
):
if action_mask is not None:
experience_type = "multi_turn"
@@ -229,6 +233,11 @@ def __init__( # noqa: C901
else:
self.multi_modal_inputs[key] = value

# Handle teacher_logprobs
if isinstance(teacher_logprobs, list):
teacher_logprobs = torch.tensor(teacher_logprobs, dtype=torch.float32)
self.teacher_logprobs = teacher_logprobs

if not isinstance(self.tokens, Tensor):
self.tokens = torch.tensor(self.tokens)
if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
@@ -239,6 +248,8 @@ def __init__( # noqa: C901
self.chosen = torch.tensor(self.chosen)
if self.rejected is not None and not isinstance(self.rejected, Tensor):
self.rejected = torch.tensor(self.rejected)
if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor):
self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32)

def serialize(self) -> bytes:
"""Serialize the experience to bytes."""
@@ -341,6 +352,14 @@ def gather(
else:
multi_modal_inputs = None

# gather teacher_logprobs
if all(exp.teacher_logprobs is not None for exp in experiences):
teacher_logprobs = gather_response_attrs(
experiences, "teacher_logprobs", max_response_length
)
else:
teacher_logprobs = None

exps = Experiences(
eids=eids,
tokens=tokens,
@@ -353,6 +372,7 @@ def gather(
prompt_length=max_prompt_length,
logprobs=logprobs,
multi_modal_inputs=multi_modal_inputs,
teacher_logprobs=teacher_logprobs,
)
if custom_fields is not None:
for custom_field in custom_fields:
@@ -442,6 +462,7 @@ class Experiences:
custom_fields: List[str] = field(
default_factory=list
) # Custom fields to include in the gathered experiences
teacher_logprobs: Optional[Tensor] = None # [batch_size, response_length]

@property
def batch_size(self) -> int: