3 changes: 2 additions & 1 deletion README.md
@@ -130,6 +130,7 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
| AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |



@@ -142,7 +143,7 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
- [Step 2: prepare dataset and model](#step-2-prepare-dataset-and-model)
- [Step 3: configurations](#step-3-configurations)
- [Step 4: run the RFT process](#step-4-run-the-rft-process)
- [Contribution guide](#contribution-guide)
- [Contribution Guide](#contribution-guide)
- [Acknowledgements](#acknowledgements)
- [Citation](#citation)

1 change: 1 addition & 0 deletions README_zh.md
@@ -129,6 +129,7 @@ Trinity-RFT provides corresponding features for users with different backgrounds and goals:
| AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |



2 changes: 1 addition & 1 deletion benchmark/reports/gsm8k.md
@@ -188,7 +188,7 @@ class VerlGSM8kWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
self.reset(task)
super().__init__(
Binary file added docs/sphinx_doc/assets/opd_acc.png
Binary file added docs/sphinx_doc/assets/opd_kl.png
2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/main.md
@@ -86,6 +86,8 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
| AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |




11 changes: 6 additions & 5 deletions docs/sphinx_doc/source/tutorial/develop_workflow.md
@@ -93,11 +93,12 @@ class Workflow(ABC):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
self.task = task
self.model = model
self.auxiliary_models = auxiliary_models
self.auxiliary_model_wrappers = auxiliary_models
self.auxiliary_models = ... # OpenAI clients auto-derived from ModelWrapper

@abstractmethod
def run(self) -> List[Experience]:
@@ -110,7 +111,7 @@ During initialization, `Workflow` receives the following parameters:

- `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
- `model`({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `auxiliary_models`(`List[openai.OpenAI]`):A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
- `auxiliary_models`(`List[ModelWrapper]`): A list of auxiliary model wrappers. OpenAI clients can be accessed via `self.auxiliary_models` (automatically derived based on the workflow's `is_async` setting).

```{tip}
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
@@ -440,10 +441,10 @@ class MyWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.judge_model = self.auxiliary_models[0] # Use the first auxiliary model as the judge
self.judge_model = self.auxiliary_models[0] # OpenAI client auto-derived from ModelWrapper

def run(self) -> List[Experience]:
response = self.do_something()
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_react.md
@@ -82,7 +82,7 @@ class AgentScopeReActWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
# initialize the agent
self.agent = AgentScopeReActAgent(
1 change: 1 addition & 0 deletions docs/sphinx_doc/source_zh/main.md
@@ -82,6 +82,7 @@ Trinity-RFT provides corresponding features for users with different backgrounds and goals:
| AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |



11 changes: 6 additions & 5 deletions docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
@@ -92,11 +92,12 @@ class Workflow(ABC):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None, # mainly used for LLM-as-a-judge scenarios; can be ignored here
auxiliary_models: Optional[List[ModelWrapper]] = None, # mainly used for LLM-as-a-judge scenarios; can also serve as the teacher for distillation
):
self.task = task
self.model = model
self.auxiliary_models = auxiliary_models
self.auxiliary_model_wrappers = auxiliary_models
self.auxiliary_models = ... # OpenAI clients auto-derived from ModelWrapper

@abstractmethod
def run(self) -> List[Experience]:
@@ -109,7 +110,7 @@ class Workflow(ABC):

- `task`({class}`trinity.common.workflows.Task`): A single task from the task dataset.
- `model`({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `auxiliary_models`(`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs and are mainly used for LLM-as-a-judge scenarios.
- `auxiliary_models`(`List[ModelWrapper]`): A list of auxiliary model wrappers. OpenAI clients can be accessed via `self.auxiliary_models` (automatically derived based on the workflow's `is_async` setting).

The following is an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use `format_args` for further customization.

@@ -437,10 +438,10 @@ class MyWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.judge_model = self.auxiliary_models[0] # Use the first auxiliary model as the judge
self.judge_model = self.auxiliary_models[0] # OpenAI client auto-derived from ModelWrapper

def run(self) -> List[Experience]:
response = self.do_something()
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source_zh/tutorial/example_react.md
@@ -88,7 +88,7 @@ class AgentScopeReActWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
# initialize the agent
self.agent = AgentScopeReActAgent(
4 changes: 1 addition & 3 deletions examples/learn_to_ask/workflow/workflow_learn2ask.py
@@ -7,8 +7,6 @@
import time
from typing import List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import SimpleWorkflow, Task
@@ -36,7 +34,7 @@ def __init__(
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
self.train_mode = task.workflow_args.get("train_mode", "Ra+Rs")
self.fusion_mode = task.workflow_args.get("fusion_mode", "default")
36 changes: 36 additions & 0 deletions examples/opd_gsm8k/README.md
@@ -0,0 +1,36 @@
# Example: On-Policy Distillation on the GSM8K Dataset

This example demonstrates training with the On-Policy Distillation (OPD) algorithm on the GSM8K dataset.

On-Policy Distillation is a knowledge distillation method. In this example:
1. **Student model** (`Qwen/Qwen2.5-1.5B-Instruct`) generates trajectories with logprobs
2. **Teacher model** (`Qwen/Qwen2.5-Math-7B-Instruct`) computes logprobs on the same trajectories
3. The advantage is computed as: `advantages = kl_coef * (teacher_logprobs - student_logprobs)`
4. The student model is trained to minimize this KL divergence, effectively learning from the teacher
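For intuition, here is a minimal, self-contained sketch of step 3 with made-up tensor values (this is only an illustration; the actual implementation added by this PR lives in `trinity/algorithm/advantage_fn/on_policy_distill_advantage.py`):

```python
import torch

# Toy per-token logprobs with shape [batch, response_len]; values are illustrative only.
student_logprobs = torch.tensor([[-0.51, -1.20, -0.83]])
teacher_logprobs = torch.tensor([[-0.42, -0.95, -0.80]])
response_mask = torch.tensor([[1.0, 1.0, 0.0]])  # last position is padding
kl_coef = 1.0

# Per-token advantage: positive wherever the teacher assigns higher probability
# to the student's own sampled token than the student does.
advantages = kl_coef * (teacher_logprobs - student_logprobs) * response_mask

# Sequence-level estimate of the student-teacher KL on the sampled tokens.
kl_per_seq = ((student_logprobs - teacher_logprobs) * response_mask).sum(dim=-1)

print(advantages)  # ≈ tensor([[0.09, 0.25, 0.00]])
print(kl_per_seq)  # ≈ tensor([-0.34])
```

Because these advantages feed the usual policy-gradient update, pushing them up is the same as pushing the per-token KL to the teacher down.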

## Key Configuration

- **Algorithm**: `on_policy_distill`
- **Workflow**: `on_policy_distill_workflow`
- **Student Model**: `Qwen/Qwen2.5-1.5B-Instruct`
- **Teacher Model**: `Qwen/Qwen2.5-Math-7B-Instruct` (configured as auxiliary model)

## Running the Example

Download the model checkpoints and adjust the config file as needed, then run:
```bash
trinity run examples/opd_gsm8k/opd_gsm8k.yaml
```

Then you are all set! It should be pretty simple 😄, and training should converge very quickly. The plots below show the accuracy and KL curves from this example:



![](../../docs/sphinx_doc/assets/opd_acc.png)
![](../../docs/sphinx_doc/assets/opd_kl.png)


## References

- https://arxiv.org/pdf/2306.13649
- https://thinkingmachines.ai/blog/on-policy-distillation/
74 changes: 74 additions & 0 deletions examples/opd_gsm8k/opd_gsm8k.yaml
@@ -0,0 +1,74 @@
project: "Trinity-RFT-gsm8k-opd"
name: "qwen2.5-1.5B-distill-from-math-7B-lr1e-5"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: on_policy_distill
repeat_times: 8
optimizer:
lr: 1e-5
advantage_fn_args:
kl_coef: 1.0
model:
# Student model
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
max_response_tokens: 1024
max_model_len: 2048
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k}
subset_name: main
split: train
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
# Use on_policy_distill_math_workflow for Qwen2.5-Math style format with accuracy reward
default_workflow_type: 'on_policy_distill_math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_opd_buffer
storage_type: queue
explorer:
eval_interval: 50
runner_per_model: 8
rollout_model:
# Student model for rollout
engine_num: 4
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
auxiliary_models:
# Teacher model for distillation
- model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-Math-7B-Instruct}
engine_num: 1
tensor_parallel_size: 2
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
max_model_len: 4096
max_prompt_tokens: 2048
max_response_tokens: 1024
synchronizer:
sync_method: 'nccl'
sync_interval: 1
sync_timeout: 1200
trainer:
save_interval: 100
grad_clip: 1.0
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
monitor:
monitor_type: wandb
14 changes: 8 additions & 6 deletions tests/explorer/workflow_test.py
@@ -48,9 +48,10 @@ def __init__(self, model, task: Task, auxiliary_models=None):
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]
self.repeat_times = task.rollout_args.n
if auxiliary_models is not None:
for model in auxiliary_models:
assert isinstance(model, openai.OpenAI)
# Check self.auxiliary_models (OpenAI clients derived from ModelWrapper)
if self.auxiliary_models is not None:
for m in self.auxiliary_models:
assert isinstance(m, openai.OpenAI)

def reset(self, task: Task):
self.obj = task.raw_task
@@ -92,9 +93,10 @@ def __init__(self, model, task: Task, auxiliary_models=None):
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]
self.repeat_times = task.rollout_args.n
if auxiliary_models is not None:
for model in auxiliary_models:
assert isinstance(model, openai.AsyncOpenAI)
# Check self.auxiliary_models (AsyncOpenAI clients derived from ModelWrapper)
if self.auxiliary_models is not None:
for m in self.auxiliary_models:
assert isinstance(m, openai.AsyncOpenAI)

def reset(self, task: Task):
self.obj = task.raw_task
1 change: 1 addition & 0 deletions trinity/algorithm/__init__.py
@@ -27,6 +27,7 @@
"sppo": "trinity.algorithm.algorithm.sPPOAlgorithm",
"rec": "trinity.algorithm.algorithm.RECAlgorithm",
"multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm",
"on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm",
},
)

1 change: 1 addition & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -17,6 +17,7 @@
"asymre": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREGroupAdvantage",
"asymre_verl": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREAdvantageFn",
"rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage",
"on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage",
},
)

68 changes: 68 additions & 0 deletions trinity/algorithm/advantage_fn/on_policy_distill_advantage.py
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""On-Policy Distillation advantage computation.

Reference: Tinker library's on-policy distillation.

advantages = -(student_logprobs - teacher_logprobs)
           = teacher_logprobs - student_logprobs
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn


class OnPolicyDistillAdvantage(AdvantageFn):
    """Advantage function for on-policy distillation.

    Computes: advantages = kl_coef * (teacher_logprobs - student_logprobs)

    The teacher_logprobs should be stored in Experience.teacher_logprobs
    by the workflow during exploration.
    """

    def __init__(self, kl_coef: float = 1.0) -> None:
        self.kl_coef = kl_coef

    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
        """Compute advantages from teacher and student logprobs.

        Args:
            exps: DataProto containing:
                - old_log_probs: student's sampling logprobs [batch, seq]
                - teacher_log_probs: teacher's logprobs [batch, seq]
                - response_mask: mask for response tokens [batch, seq]

        Returns:
            exps: DataProto with advantages and returns added
            metrics: Dict with kl and advantage statistics
        """
        metrics = {}

        old_log_probs = exps.batch["old_log_probs"]  # student sampling logprobs
        teacher_log_probs = exps.batch["teacher_log_probs"]
        response_mask = exps.batch["response_mask"]

        # advantages = -(student - teacher) = teacher - student
        advantages = self.kl_coef * (teacher_log_probs - old_log_probs)

        # Apply mask
        advantages = advantages * response_mask

        exps.batch["advantages"] = advantages
        exps.batch["returns"] = advantages.clone()

        # Metrics
        kl_per_token = old_log_probs - teacher_log_probs
        kl_sum = (kl_per_token * response_mask).sum(dim=-1)
        metrics["kl/mean"] = kl_sum.mean().item()
        metrics["kl/std"] = kl_sum.std().item() if kl_sum.numel() > 1 else 0.0
        metrics["advantages/mean"] = advantages.sum(dim=-1).mean().item()

        return exps, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {"kl_coef": 1.0}