
Commit d2be0b7

Implement On-Policy Distillation (#444)
1 parent e412fbe commit d2be0b7

40 files changed, +551 -80 lines

README.md

Lines changed: 2 additions & 1 deletion

@@ -130,6 +130,7 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
 | AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
 | CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
 | SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
+| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |

@@ -142,7 +143,7 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
 - [Step 2: prepare dataset and model](#step-2-prepare-dataset-and-model)
 - [Step 3: configurations](#step-3-configurations)
 - [Step 4: run the RFT process](#step-4-run-the-rft-process)
-- [Contribution guide](#contribution-guide)
+- [Contribution Guide](#contribution-guide)
 - [Acknowledgements](#acknowledgements)
 - [Citation](#citation)
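
For readers new to the algorithm: per the linked blog post and paper, on-policy distillation samples responses from the student model itself and trains the student to minimize a per-token reverse KL divergence against a fixed teacher. Below is a minimal, illustrative PyTorch sketch of that loss; the function name, tensor shapes, and masking convention are assumptions for exposition, not the repository's actual `on_policy_distill` implementation (see the linked workflow code for that).

```python
import torch
import torch.nn.functional as F


def reverse_kl_distill_loss(
    student_logits: torch.Tensor,  # [batch, seq, vocab], student scores on its own samples
    teacher_logits: torch.Tensor,  # [batch, seq, vocab], teacher scores on the same tokens
    response_mask: torch.Tensor,   # [batch, seq], 1.0 on response tokens, 0.0 on prompt/padding
) -> torch.Tensor:
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    # KL(student || teacher) at each position, summed over the vocabulary
    kl = (student_logp.exp() * (student_logp - teacher_logp)).sum(dim=-1)
    # average over response tokens only, so prompt tokens carry no gradient
    return (kl * response_mask).sum() / response_mask.sum().clamp(min=1.0)
```

Because the expectation is taken under the student's own samples, the teacher corrects the student exactly where it actually goes wrong, which is what distinguishes this from off-policy (teacher-sampled) distillation.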

README_zh.md

Lines changed: 1 addition & 0 deletions

@@ -129,6 +129,7 @@ Trinity-RFT provides corresponding features for users with different backgrounds and goals:
 | AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
 | CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
 | SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
+| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |

benchmark/reports/gsm8k.md

Lines changed: 1 addition & 1 deletion

@@ -188,7 +188,7 @@ class VerlGSM8kWorkflow(Workflow):
         *,
         task: Task,
         model: ModelWrapper,
-        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+        auxiliary_models: Optional[List[ModelWrapper]] = None,
     ):
         self.reset(task)
         super().__init__(

docs/sphinx_doc/assets/opd_acc.png

168 KB

docs/sphinx_doc/assets/opd_kl.png

169 KB

docs/sphinx_doc/source/main.md

Lines changed: 2 additions & 0 deletions

@@ -86,6 +86,8 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
 | AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
 | CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
 | SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
+| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |
+

docs/sphinx_doc/source/tutorial/develop_workflow.md

Lines changed: 6 additions & 5 deletions

@@ -93,11 +93,12 @@ class Workflow(ABC):
         *,
         task: Task,
         model: ModelWrapper,
-        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+        auxiliary_models: Optional[List[ModelWrapper]] = None,
     ):
         self.task = task
         self.model = model
-        self.auxiliary_models = auxiliary_models
+        self.auxiliary_model_wrappers = auxiliary_models
+        self.auxiliary_models = ...  # OpenAI clients auto-derived from ModelWrapper

     @abstractmethod
     def run(self) -> List[Experience]:

@@ -110,7 +111,7 @@ During initialization, `Workflow` receives the following parameters:

 - `task` ({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
 - `model` ({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full-sequence token ids `tokens`, prompt token length `prompt_length`, and a list of output token logprobs `logprobs`).
-- `auxiliary_models` (`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
+- `auxiliary_models` (`List[ModelWrapper]`): A list of auxiliary model wrappers. You can access OpenAI clients via `self.auxiliary_models` (auto-derived based on the workflow's `is_async`).

 ```{tip}
 You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.

@@ -440,10 +441,10 @@ class MyWorkflow(Workflow):
         *,
         task: Task,
         model: ModelWrapper,
-        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+        auxiliary_models: Optional[List[ModelWrapper]] = None,
     ):
         super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
-        self.judge_model = self.auxiliary_models[0]  # Use the first auxiliary model as the judge
+        self.judge_model = self.auxiliary_models[0]  # OpenAI client auto-derived from ModelWrapper

     def run(self) -> List[Experience]:
         response = self.do_something()
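
The practical upshot of this signature change: workflows now receive auxiliary models as `ModelWrapper`s, and OpenAI clients are derived from them automatically, so the same mechanism serves both LLM-as-a-judge and distillation-teacher use cases. A hedged sketch of a custom workflow under the new signature (the class name is hypothetical, and the `Experience`/`Workflow` import paths are assumptions based on the docs above):

```python
from typing import List, Optional

from trinity.common.experience import Experience  # import path assumed
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows import Task, Workflow  # Workflow path assumed alongside Task


class TeacherGuidedWorkflow(Workflow):  # hypothetical example, not from this commit
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[ModelWrapper]] = None,
    ):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        # After this commit, self.auxiliary_models holds OpenAI clients
        # auto-derived from the wrappers (sync or async depending on the
        # workflow's is_async); the raw ModelWrapper objects remain
        # available as self.auxiliary_model_wrappers.
        self.teacher = self.auxiliary_models[0]

    def run(self) -> List[Experience]:
        # Roll out with the student (self.model), score or relabel with
        # self.teacher via the standard OpenAI chat API, and return
        # experiences for training.
        ...
```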

docs/sphinx_doc/source/tutorial/example_react.md

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ class AgentScopeReActWorkflow(Workflow):
         *,
         task: Task,
         model: ModelWrapper,
-        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+        auxiliary_models: Optional[List[ModelWrapper]] = None,
     ):
         # initialize the agent
         self.agent = AgentScopeReActAgent(

docs/sphinx_doc/source_zh/main.md

Lines changed: 1 addition & 0 deletions

@@ -82,6 +82,7 @@ Trinity-RFT provides corresponding features for users with different backgrounds and goals:
 | AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
 | CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
 | SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
+| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |

docs/sphinx_doc/source_zh/tutorial/develop_workflow.md

Lines changed: 6 additions & 5 deletions

@@ -92,11 +92,12 @@ class Workflow(ABC):
         *,
         task: Task,
         model: ModelWrapper,
-        auxiliary_models: Optional[List[openai.OpenAI]] = None,  # mainly for LLM-as-a-judge scenarios; can be ignored here
+        auxiliary_models: Optional[List[ModelWrapper]] = None,  # mainly for LLM-as-a-judge scenarios; can also serve as the teacher for distillation
     ):
         self.task = task
         self.model = model
-        self.auxiliary_models = auxiliary_models
+        self.auxiliary_model_wrappers = auxiliary_models
+        self.auxiliary_models = ...  # OpenAI clients auto-derived from ModelWrapper

     @abstractmethod
     def run(self) -> List[Experience]:

@@ -109,7 +110,7 @@ class Workflow(ABC):

 - `task` ({class}`trinity.common.workflows.Task`): A single task from the task dataset.
 - `model` ({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an OpenAI-like interface that accepts a list of conversation messages and returns LLM-generated content (including the reply text `response_text`, full-sequence token ids `tokens`, prompt token length `prompt_length`, and the list of output token logprobs `logprobs`).
-- `auxiliary_models` (`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs, mainly for LLM-as-a-judge scenarios.
+- `auxiliary_models` (`List[ModelWrapper]`): A list of `ModelWrapper`s for auxiliary models. OpenAI clients can be accessed via `self.auxiliary_models` (auto-derived based on the workflow's `is_async`).

 The following is an example that initializes a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use `format_args` for further customization.

@@ -437,10 +438,10 @@ class MyWorkflow(Workflow):
         *,
         task: Task,
         model: ModelWrapper,
-        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+        auxiliary_models: Optional[List[ModelWrapper]] = None,
     ):
         super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
-        self.judge_model = self.auxiliary_models[0]  # use the first auxiliary model as the judge
+        self.judge_model = self.auxiliary_models[0]  # OpenAI client auto-derived from ModelWrapper

     def run(self) -> List[Experience]:
         response = self.do_something()
