Commit 2dd56e0

actual run and add on-policy distillation example in GSM8K
1 parent b4ec917 commit 2dd56e0

File tree

7 files changed (+230, -10 lines)

docs/sphinx_doc/assets/opd_acc.png

168 KB

docs/sphinx_doc/assets/opd_kl.png

169 KB

examples/opd_gsm8k/README.md

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# Example: On-Policy Distillation on the GSM8K Dataset

This example demonstrates training with the On-Policy Distillation (OPD) algorithm on the GSM8K dataset.

On-Policy Distillation is a knowledge distillation method. In this example:

1. **Student model** (`Qwen/Qwen2.5-1.5B-Instruct`) generates trajectories along with its token logprobs
2. **Teacher model** (`Qwen/Qwen2.5-Math-7B-Instruct`) computes logprobs on the same trajectories
3. The advantage is computed as `advantages = kl_coef * (teacher_logprobs - student_logprobs)`
4. The student model is trained to minimize this KL divergence, effectively learning from the teacher (see the sketch below)
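A minimal sketch of this advantage computation, assuming per-token logprob tensors and PyTorch; the tensor values and variable names here are illustrative, not the trainer's exact internals (`kl_coef` comes from `algorithm.advantage_fn_args.kl_coef` in `opd_gsm8k.yaml`):

```python
import torch

kl_coef = 1.0  # algorithm.advantage_fn_args.kl_coef

# Per-token logprobs computed on the *same* student-generated trajectory
student_logprobs = torch.tensor([-0.7, -1.2, -0.3])  # from the student (rollout) model
teacher_logprobs = torch.tensor([-0.5, -0.9, -0.4])  # from the teacher (auxiliary) model

# Advantage used by the trainer: positive where the teacher assigns the sampled
# token a higher probability than the student did
advantages = kl_coef * (teacher_logprobs - student_logprobs)

# Sequence-level reverse-KL estimate, logged by the workflow as "kl_divergence"
kl = (student_logprobs - teacher_logprobs).sum().item()
```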
## Key Configuration

- **Algorithm**: `on_policy_distill`
- **Workflow**: `on_policy_distill_math_workflow` (Qwen2.5-Math style prompt with accuracy reward)
- **Student Model**: `Qwen/Qwen2.5-1.5B-Instruct`
- **Teacher Model**: `Qwen/Qwen2.5-Math-7B-Instruct` (configured as an auxiliary model)
## Running the Example

Download the model checkpoints and adjust the config file to your environment, then run:

```bash
trinity run examples/opd_gsm8k/opd_gsm8k.yaml
```
Then you are all set! It should be pretty simple 😄, and training should converge very quickly (much faster than RL).

![](../../docs/sphinx_doc/assets/opd_acc.png)
![](../../docs/sphinx_doc/assets/opd_kl.png)
## References

- https://arxiv.org/pdf/2306.13649
- https://thinkingmachines.ai/blog/on-policy-distillation/

examples/opd_gsm8k/opd_gsm8k.yaml

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
```yaml
project: "Trinity-RFT-gsm8k-opd"
name: "qwen2.5-1.5B-distill-from-math-7B-lr1e-5"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
  algorithm_type: on_policy_distill
  repeat_times: 8
  optimizer:
    lr: 1e-5
  advantage_fn_args:
    kl_coef: 1.0
model:
  # Student model
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
  max_response_tokens: 1024
  max_model_len: 2048
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 1
  batch_size: 96
  explorer_input:
    taskset:
      name: gsm8k
      storage_type: file
      path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k}
      subset_name: main
      split: train
      format:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        temperature: 1.0
    # Use on_policy_distill_math_workflow for Qwen2.5-Math style format with accuracy reward
    default_workflow_type: 'on_policy_distill_math_workflow'
  trainer_input:
    experience_buffer:
      name: gsm8k_opd_buffer
      storage_type: queue
explorer:
  eval_interval: 50
  runner_per_model: 8
  rollout_model:
    # Student model for rollout
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
  auxiliary_models:
    # Teacher model for distillation
    - model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-Math-7B-Instruct}
      engine_num: 1
      tensor_parallel_size: 2
      enable_prefix_caching: false
      enforce_eager: true
      dtype: bfloat16
      seed: 42
      max_model_len: 4096
      max_prompt_tokens: 2048
      max_response_tokens: 1024
synchronizer:
  sync_method: 'nccl'
  sync_interval: 1
  sync_timeout: 1200
trainer:
  save_interval: 100
  grad_clip: 1.0
  use_dynamic_bsz: true
  max_token_len_per_gpu: 16384
  ulysses_sequence_parallel_size: 1
monitor:
  monitor_type: wandb
```
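The `${oc.env:VAR,default}` entries above are OmegaConf environment-variable interpolations. A minimal sketch of how they resolve, assuming OmegaConf's built-in `oc.env` resolver (the variable name comes from this config; the snippet itself is illustrative, not Trinity's config-loading code):

```python
import os
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"model": {"model_path": "${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}"}}
)

# With TRINITY_MODEL_PATH unset, the default after the comma is used
print(OmegaConf.to_container(cfg, resolve=True))

# Exporting the variable overrides the default, e.g. to point at a local checkpoint
os.environ["TRINITY_MODEL_PATH"] = "/path/to/local/Qwen2.5-1.5B-Instruct"
print(cfg.model.model_path)
```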

trinity/common/config.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -80,7 +80,7 @@ class FormatConfig:
 class GenerationConfig:
     temperature: Optional[float] = None  # 1.0
     top_p: Optional[float] = None  # 1.0
-    top_k: Optional[int] = None  # -1
+    top_k: int = -1  # -1 means disabled
     logprobs: Optional[int] = None  # 0 # vLLM return `logprobs + 1` elements
     max_tokens: Optional[int] = None  # if None, use model.max_response_tokens
     # repeat each task for `n` times
```
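This change makes `top_k` default to `-1` instead of `None`. A minimal sketch of what that value means at sampling time, assuming a vLLM backend (consistent with the `logprobs` comment above); the snippet is illustrative, not Trinity's own generation code:

```python
from vllm import SamplingParams

# top_k=-1 disables top-k filtering: all tokens remain candidates at every step
params = SamplingParams(temperature=1.0, top_p=1.0, top_k=-1, max_tokens=1024)
print(params.top_k)  # -1
```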

trinity/common/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -48,7 +48,9 @@
     MathTrainableRULERWorkflow,
 )
 from trinity.common.workflows.on_policy_distill_workflow import (
+    AsyncOnPolicyDistillMathWorkflow,
     AsyncOnPolicyDistillWorkflow,
+    OnPolicyDistillMathWorkflow,
     OnPolicyDistillWorkflow,
 )
 from trinity.common.workflows.rubric_judge_workflow import RubricJudgeWorkflow
@@ -103,4 +105,6 @@
     # On-policy distillation workflows
     "OnPolicyDistillWorkflow",
     "AsyncOnPolicyDistillWorkflow",
+    "OnPolicyDistillMathWorkflow",
+    "AsyncOnPolicyDistillMathWorkflow",
 ]
```
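With these exports, the new classes can be imported from the package root, and their registry keys are what the YAML's `default_workflow_type` refers to. A small usage sketch (the registry-key comments restate the decorators in the workflow module):

```python
from trinity.common.workflows import (
    AsyncOnPolicyDistillMathWorkflow,
    OnPolicyDistillMathWorkflow,
)

# Registered as 'on_policy_distill_math_workflow' / 'async_on_policy_distill_math_workflow',
# so a config selects them via, e.g.:
#   default_workflow_type: 'on_policy_distill_math_workflow'
print(OnPolicyDistillMathWorkflow.is_async)  # True, inherited from OnPolicyDistillWorkflow
```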

trinity/common/workflows/on_policy_distill_workflow.py

Lines changed: 115 additions & 9 deletions
```diff
@@ -11,22 +11,28 @@
 5. Train with importance_sampling loss
 """
 
+from dataclasses import asdict
 from typing import List, Optional
 
 import openai
 
 from trinity.common.experience import Experience
 from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, BaseSimpleWorkflow, Task
+from trinity.common.rewards.qwen25_eval import verify_math_answer
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
 
 
 @WORKFLOWS.register_module("on_policy_distill_workflow")
-class OnPolicyDistillWorkflow(BaseSimpleWorkflow):
+class OnPolicyDistillWorkflow(Workflow):
     """On-policy distillation workflow.
 
     Computes and stores teacher_logprobs in experience.info.
     The advantage_fn in trainer will compute:
     advantages = teacher_logprobs - student_logprobs
+
+    Note: This workflow does NOT use reward_fn because:
+    - Advantage is computed from teacher-student logprobs difference
+    - No external reward signal is needed
     """
 
     is_async: bool = True
@@ -41,8 +47,13 @@ def __init__(
         auxiliary_models: Optional[List[openai.OpenAI]] = None,
         auxiliary_model_wrappers: Optional[List[ModelWrapper]] = None,
     ):
-        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
-        self.auxiliary_model_wrappers = auxiliary_model_wrappers
+        super().__init__(
+            task=task,
+            model=model,
+            auxiliary_models=auxiliary_models,
+            auxiliary_model_wrappers=auxiliary_model_wrappers,
+        )
+        self.reset(task)
 
         assert (
             auxiliary_model_wrappers is not None and len(auxiliary_model_wrappers) >= 1
@@ -51,6 +62,49 @@
 
         self.temperature = task.workflow_args.get("temperature", 1.0)
 
+    def reset(self, task: Task):
+        """Reset the workflow with a new task.
+
+        Unlike BaseSimpleWorkflow, this does NOT require reward_fn.
+        """
+        self.task = task
+        self.format_args = task.format_args
+        self.system_prompt = task.format_args.system_prompt
+        self.reply_prefix = task.format_args.reply_prefix
+        self.raw_task = task.raw_task
+        self.task_desc = task.task_desc
+        self.truth = task.truth
+
+    def set_repeat_times(self, repeat_times, run_id_base):
+        self.repeat_times = repeat_times
+        self.task.rollout_args.n = repeat_times
+        self.run_id_base = run_id_base
+
+    @property
+    def rollout_args(self):
+        return asdict(self.task.rollout_args)
+
+    def format_messages(self):
+        """Format messages for the instruct model.
+
+        Default format: system_prompt (optional) + task_desc + reply_prefix (optional)
+        """
+        messages = []
+        if self.system_prompt:
+            messages.append({"role": "system", "content": self.system_prompt})
+        messages.append({"role": "user", "content": self.task_desc})
+        if self.reply_prefix:
+            messages.append({"role": "assistant", "content": self.reply_prefix})
+        return messages
+
+    def compute_reward(self, response: Experience) -> float:
+        """Compute reward for a response.
+
+        In base class, returns 0.0 as advantage is computed from teacher-student logprobs.
+        Subclasses can override this to compute actual rewards.
+        """
+        return 0.0
+
     async def run_async(self) -> List[Experience]:
         messages = self.format_messages()
 
@@ -79,13 +133,16 @@ async def run_async(self) -> List[Experience]:
             # Step 3: Store teacher_logprobs for advantage_fn
             response.teacher_logprobs = teacher_resp_logprobs
 
-            # Set a dummy reward (actual advantage computed by advantage_fn)
-            response.reward = 0.0
-            response.eid.run = i + self.run_id_base
-
-            # Metrics for monitoring
+            # Initialize metrics
             if response.metrics is None:
                 response.metrics = {}
+
+            # Compute reward (subclasses can override compute_reward)
+            response.reward = self.compute_reward(response)
+
+            response.eid.run = i + self.run_id_base
+
+            # KL divergence for monitoring
             kl = (student_resp_logprobs - teacher_resp_logprobs).sum().item()
             response.metrics["kl_divergence"] = kl
 
@@ -94,4 +151,53 @@
 
 @WORKFLOWS.register_module("async_on_policy_distill_workflow")
 class AsyncOnPolicyDistillWorkflow(OnPolicyDistillWorkflow):
+    """Alias for OnPolicyDistillWorkflow (already async)."""
+
+    pass
+
+
+@WORKFLOWS.register_module("on_policy_distill_math_workflow")
+class OnPolicyDistillMathWorkflow(OnPolicyDistillWorkflow):
+    """On-policy distillation workflow with Qwen2.5-Math style format.
+
+    This workflow:
+    - Uses Qwen2.5-Math style prompt format (same as math_eval_workflow)
+    - Computes accuracy using verify_math_answer as reward
+    - Suitable for math reasoning tasks like GSM8K, MATH, etc.
+    """
+
+    def format_messages(self):
+        """Format messages using Qwen2.5-Math style.
+
+        System prompt: "You are a helpful assistant."
+        User prompt: "{question}\nPlease reason step by step, and put your final answer within \\boxed{}."
+        """
+        system_prompt = "You are a helpful assistant."
+        user_prompt = f"{self.task_desc}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
+        return [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+
+    def compute_reward(self, response: Experience) -> float:
+        """Compute accuracy as reward using Qwen2.5-Math evaluation.
+
+        Returns 1.0 if answer is correct, 0.0 otherwise.
+        """
+        if response.response_text and self.truth:
+            accuracy, _ = verify_math_answer(
+                response_text=response.response_text, ground_truth=self.truth
+            )
+            # Store accuracy in metrics
+            if response.metrics is None:
+                response.metrics = {}
+            response.metrics["accuracy"] = accuracy
+            return float(accuracy)
+        return 0.0
+
+
+@WORKFLOWS.register_module("async_on_policy_distill_math_workflow")
+class AsyncOnPolicyDistillMathWorkflow(OnPolicyDistillMathWorkflow):
+    """Alias for OnPolicyDistillMathWorkflow (already async)."""
+
     pass
```
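To make the math variant concrete, here is a hedged sketch of the messages `OnPolicyDistillMathWorkflow.format_messages()` produces for a GSM8K-style task; the question text is illustrative, not from the dataset loader:

```python
# What format_messages() returns when self.task_desc holds a GSM8K question
question = "Natalia sold 48 clips in April and half as many in May. How many clips did she sell in total?"

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": f"{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
    },
]
```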
