4 changes: 4 additions & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -11,6 +11,9 @@
from trinity.algorithm.advantage_fn.multi_step_grpo_advantage import (
StepWiseGRPOAdvantageFn,
)
from trinity.algorithm.advantage_fn.on_policy_distill_advantage import (
OnPolicyDistillAdvantage,
)
from trinity.algorithm.advantage_fn.opmd_advantage import (
OPMDAdvantageFn,
OPMDGroupAdvantage,
@@ -40,4 +43,5 @@
"REINFORCEGroupAdvantage",
"ASYMREAdvantageFn",
"RECGroupedAdvantage",
"OnPolicyDistillAdvantage",
]
69 changes: 69 additions & 0 deletions trinity/algorithm/advantage_fn/on_policy_distill_advantage.py
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""On-Policy Distillation advantage computation.

Reference: Tinker library's on-policy distillation.

advantages = -(student_logprobs - teacher_logprobs)
= teacher_logprobs - student_logprobs
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn


@ADVANTAGE_FN.register_module("on_policy_distill")
class OnPolicyDistillAdvantage(AdvantageFn):
"""Advantage function for on-policy distillation.

Computes: advantages = kl_coef * (teacher_logprobs - student_logprobs)

The teacher_logprobs should be stored in Experience.teacher_logprobs
by the workflow during exploration.
"""

def __init__(self, kl_coef: float = 1.0) -> None:
self.kl_coef = kl_coef

def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
"""Compute advantages from teacher and student logprobs.

Args:
exps: DataProto containing:
- old_log_probs: student's sampling logprobs [batch, seq]
- teacher_log_probs: teacher's logprobs [batch, seq]
- response_mask: mask for response tokens [batch, seq]

Returns:
exps: DataProto with advantages and returns added
metrics: Dict with kl and advantage statistics
"""
metrics = {}

old_log_probs = exps.batch["old_log_probs"] # student sampling logprobs
teacher_log_probs = exps.batch["teacher_log_probs"]
response_mask = exps.batch["response_mask"]

# advantages = -(student - teacher) = teacher - student
advantages = self.kl_coef * (teacher_log_probs - old_log_probs)

# Apply mask
advantages = advantages * response_mask

exps.batch["advantages"] = advantages
exps.batch["returns"] = advantages.clone()

# Metrics
kl_per_token = old_log_probs - teacher_log_probs
kl_sum = (kl_per_token * response_mask).sum(dim=-1)
metrics["kl/mean"] = kl_sum.mean().item()
metrics["kl/std"] = kl_sum.std().item() if kl_sum.numel() > 1 else 0.0
metrics["advantages/mean"] = advantages.sum(dim=-1).mean().item()

return exps, metrics

@classmethod
def default_args(cls) -> Dict:
return {"kl_coef": 1.0}
31 changes: 31 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -384,3 +384,34 @@ def default_config(cls) -> Dict:
"kl_loss_fn": "k2",
"entropy_loss_fn": "default",
}


@ALGORITHM_TYPE.register_module("on_policy_distill")
class OnPolicyDistillAlgorithm(AlgorithmType):
"""On-Policy Distillation Algorithm.

Reference: Tinker library.

The workflow stores teacher logprobs in Experience.teacher_logprobs.
The trainer's advantage_fn computes: advantages = teacher_logprobs - student_logprobs
The trainer uses the importance_sampling loss (no clipping).
"""

use_critic: bool = False
use_reference: bool = False
compute_advantage_in_trainer: bool = True # advantage_fn computes from teacher_logprobs
can_balance_batch: bool = True
schema: str = "experience"

@classmethod
def default_config(cls) -> Dict:
return {
"repeat_times": 8,
"advantage_fn": "on_policy_distill",
"advantage_fn_args": {"kl_coef": 1.0},
"sample_strategy": "default",
"policy_loss_fn": "importance_sampling",
"kl_penalty_fn": "none",
"kl_loss_fn": "none",
"entropy_loss_fn": "none",
}
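One way to read this default configuration: with the importance_sampling loss and advantages set to teacher minus student logprobs, a fully on-policy update (current logprobs equal to the sampling logprobs, so the probability ratio is 1) makes the loss value equal to the per-token reverse KL from student to teacher. A small sketch of that identity (standalone torch code with illustrative numbers, not Trinity API calls):

import torch

student = torch.tensor([-1.2, -0.8, -2.0])   # sampling logprobs (== current logprobs when on-policy)
teacher = torch.tensor([-1.0, -0.7, -1.4])

advantages = teacher - student               # what the advantage_fn produces with kl_coef = 1
ratio = torch.exp(student - student)         # exp(current - sampling) = 1 when on-policy
is_loss = (-(ratio * advantages)).mean()     # importance_sampling loss, token-mean

reverse_kl = (student - teacher).mean()      # per-token KL(student || teacher) estimate
assert torch.allclose(is_loss, reverse_kl)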
4 changes: 4 additions & 0 deletions trinity/algorithm/policy_loss_fn/__init__.py
@@ -6,6 +6,9 @@
from trinity.algorithm.policy_loss_fn.cispo_policy_loss import CISPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn
from trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss import (
ImportanceSamplingLossFn,
)
from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn
from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
@@ -31,4 +34,5 @@
"SFTPhiLossFn",
"sPPOPolicyLossFn",
"RECPolicyLossFn",
"ImportanceSamplingLossFn",
]
61 changes: 61 additions & 0 deletions trinity/algorithm/policy_loss_fn/importance_sampling_policy_loss.py
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""The most simple Importance Sampling policy loss.

loss = -(prob_ratio * advantages), aggregated over response tokens (default: token-mean)
where prob_ratio = exp(current_logprobs - sampling_logprobs)

Note: This loss is used for on-policy distillation.
"""

from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import aggregate_loss, masked_mean


@POLICY_LOSS_FN.register_module("importance_sampling")
class ImportanceSamplingLossFn(PolicyLossFn):
"""Pure importance sampling loss without clipping.

loss = -(ratio * advantages)
where ratio = exp(logprob - old_logprob)
"""

def __init__(
self,
backend: str = "verl",
loss_agg_mode: str = "token-mean",
) -> None:
super().__init__(backend=backend)
self.loss_agg_mode = loss_agg_mode

def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
# prob_ratio = exp(current_logprobs - sampling_logprobs)
log_ratio = logprob - old_logprob
# Clamp the log-ratio for numerical stability before exponentiation.
log_ratio = torch.clamp(log_ratio, min=-20.0, max=20.0)
ratio = torch.exp(log_ratio)

# loss = -(prob_ratio * advantages)
pg_losses = -advantages * ratio
pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)

metrics = {
"pg_loss": pg_loss.detach().item(),
"ratio/mean": masked_mean(ratio, action_mask).detach().item(),
"approx_kl": masked_mean(-log_ratio, action_mask).detach().item(),
}

return pg_loss, metrics

@classmethod
def default_args(cls) -> Dict:
return {"loss_agg_mode": "token-mean"}
21 changes: 21 additions & 0 deletions trinity/common/experience.py
@@ -133,6 +133,9 @@ class Experience:
# for multi-modal data
multi_modal_inputs: Optional[Dict[str, Tensor]] = None # Multi-modal inputs for verl trainer

# for on-policy distillation
teacher_logprobs: Optional[Tensor] = None # [resp_length]

def __init__( # noqa: C901
self,
*,
@@ -157,6 +160,7 @@ def __init__( # noqa: C901
chosen_messages=None,
rejected_messages=None,
multi_modal_inputs=None,
teacher_logprobs=None,
):
if action_mask is not None:
experience_type = "multi_turn"
@@ -229,6 +233,11 @@ def __init__( # noqa: C901
else:
self.multi_modal_inputs[key] = value

# Handle teacher_logprobs
if isinstance(teacher_logprobs, list):
teacher_logprobs = torch.tensor(teacher_logprobs, dtype=torch.float32)
self.teacher_logprobs = teacher_logprobs

if not isinstance(self.tokens, Tensor):
self.tokens = torch.tensor(self.tokens)
if self.logprobs is not None and not isinstance(self.logprobs, Tensor):
@@ -239,6 +248,8 @@
self.chosen = torch.tensor(self.chosen)
if self.rejected is not None and not isinstance(self.rejected, Tensor):
self.rejected = torch.tensor(self.rejected)
if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor):
self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32)

def serialize(self) -> bytes:
"""Serialize the experience to bytes."""
@@ -341,6 +352,14 @@ def gather(
else:
multi_modal_inputs = None

# gather teacher_logprobs
if all(exp.teacher_logprobs is not None for exp in experiences):
teacher_logprobs = gather_response_attrs(
experiences, "teacher_logprobs", max_response_length
)
else:
teacher_logprobs = None

exps = Experiences(
eids=eids,
tokens=tokens,
@@ -353,6 +372,7 @@
prompt_length=max_prompt_length,
logprobs=logprobs,
multi_modal_inputs=multi_modal_inputs,
teacher_logprobs=teacher_logprobs,
)
if custom_fields is not None:
for custom_field in custom_fields:
@@ -442,6 +462,7 @@ class Experiences:
custom_fields: List[str] = field(
default_factory=list
) # Custom fields to include in the gathered experiences
teacher_logprobs: Optional[Tensor] = None # [batch_size, response_length]

@property
def batch_size(self) -> int:
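The gather path above batches per-response teacher logprobs with gather_response_attrs; a rough sketch of that kind of batching, assuming right-padding to max_response_length with zeros (the pad_response_attr helper below is hypothetical; the real gather_response_attrs may pad or align differently):

import torch

def pad_response_attr(tensors, max_response_length, pad_value=0.0):
    # Hypothetical helper: right-pad per-response 1-D tensors into a [batch, max_len] tensor.
    batch = torch.full((len(tensors), max_response_length), pad_value)
    for i, t in enumerate(tensors):
        batch[i, : t.numel()] = t
    return batch

teacher_logprobs = [torch.tensor([-1.0, -0.7]), torch.tensor([-0.6, -0.9, -0.8])]
print(pad_response_attr(teacher_logprobs, max_response_length=4))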
7 changes: 7 additions & 0 deletions trinity/common/workflows/__init__.py
@@ -47,6 +47,10 @@
from trinity.common.workflows.math_trainable_ruler_workflow import (
MathTrainableRULERWorkflow,
)
from trinity.common.workflows.on_policy_distill_workflow import (
AsyncOnPolicyDistillWorkflow,
OnPolicyDistillWorkflow,
)
from trinity.common.workflows.rubric_judge_workflow import RubricJudgeWorkflow
from trinity.common.workflows.simple_mm_workflow import (
AsyncSimpleMMWorkflow,
@@ -96,4 +100,7 @@
"RubricJudgeWorkflow",
"AgentScopeWorkflowAdapter",
"FrozenLakeWorkflow",
# On-policy distillation workflows
"OnPolicyDistillWorkflow",
"AsyncOnPolicyDistillWorkflow",
]
99 changes: 99 additions & 0 deletions trinity/common/workflows/on_policy_distill_workflow.py
@@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
"""On-Policy Distillation Workflow.

Reference: Tinker library's on-policy distillation implementation.

Algorithm:
1. Student samples trajectories (with logprobs)
2. Teacher computes logprobs on same trajectories
3. Store teacher_logprobs in Experience.teacher_logprobs
4. Trainer's advantage_fn computes: advantages = teacher_logprobs - student_logprobs
5. Train with importance_sampling loss
"""

from typing import List, Optional

import openai
import torch

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, BaseSimpleWorkflow, Task


@WORKFLOWS.register_module("on_policy_distill_workflow")
class OnPolicyDistillWorkflow(BaseSimpleWorkflow):
"""On-policy distillation workflow.

Computes teacher logprobs and stores them in Experience.teacher_logprobs.
The advantage_fn in the trainer then computes:
advantages = teacher_logprobs - student_logprobs
"""

is_async: bool = True
can_reset: bool = True
can_repeat: bool = True

def __init__(
self,
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_model_wrappers: Optional[List[ModelWrapper]] = None,
):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.auxiliary_model_wrappers = auxiliary_model_wrappers

assert (
auxiliary_model_wrappers is not None and len(auxiliary_model_wrappers) >= 1
), "On-policy distillation requires at least one auxiliary model as teacher."
self.teacher_model = auxiliary_model_wrappers[0]

self.temperature = task.workflow_args.get("temperature", 1.0)

async def run_async(self) -> List[Experience]:
messages = self.format_messages()

# Step 1: Student samples trajectories
responses = await self.model.chat_async(messages, **self.rollout_args)

for i, response in enumerate(responses):
# Step 2: Teacher computes logprobs
teacher_logprobs = await self.teacher_model.logprobs_async(
tokens=response.tokens.tolist(),
temperature=self.temperature,
)

# Extract response portion
resp_start = response.prompt_length - 1
teacher_resp_logprobs = teacher_logprobs[resp_start:]
student_resp_logprobs = response.logprobs

# Match lengths
target_len = len(student_resp_logprobs)
if len(teacher_resp_logprobs) > target_len:
teacher_resp_logprobs = teacher_resp_logprobs[:target_len]
elif len(teacher_resp_logprobs) < target_len:
padding = torch.zeros(target_len - len(teacher_resp_logprobs))
teacher_resp_logprobs = torch.cat([teacher_resp_logprobs, padding])

# Step 3: Store teacher_logprobs for advantage_fn
response.teacher_logprobs = teacher_resp_logprobs

# Set a dummy reward (actual advantage computed by advantage_fn)
response.reward = 0.0
response.eid.run = i + self.run_id_base

# Metrics for monitoring
if response.metrics is None:
response.metrics = {}
kl = (student_resp_logprobs - teacher_resp_logprobs).sum().item()
response.metrics["kl_divergence"] = kl

return responses


@WORKFLOWS.register_module("async_on_policy_distill_workflow")
class AsyncOnPolicyDistillWorkflow(OnPolicyDistillWorkflow):
"""Alias of OnPolicyDistillWorkflow registered under the async workflow name."""
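The resp_start = response.prompt_length - 1 slice in run_async assumes the teacher returns one logprob per predicted token, i.e. an array of length len(tokens) - 1 that starts at the second token of the sequence; a toy sketch of that alignment and of the length-matching step (illustrative values, not the ModelWrapper API):

import torch

prompt_length = 4
tokens = list(range(10))                          # 4 prompt tokens + 6 response tokens
teacher_logprobs = torch.randn(len(tokens) - 1)   # one logprob per predicted token (tokens[1:])

resp_start = prompt_length - 1                    # index of the first response token's logprob
teacher_resp_logprobs = teacher_logprobs[resp_start:]

student_resp_logprobs = torch.randn(6)            # one logprob per response token from the student
target_len = len(student_resp_logprobs)
if len(teacher_resp_logprobs) > target_len:
    teacher_resp_logprobs = teacher_resp_logprobs[:target_len]
elif len(teacher_resp_logprobs) < target_len:
    padding = torch.zeros(target_len - len(teacher_resp_logprobs))
    teacher_resp_logprobs = torch.cat([teacher_resp_logprobs, padding])
assert teacher_resp_logprobs.shape == student_resp_logprobs.shape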