
Commit 17ea9ab

zpqiushuo-nvidia and shuo-nvidia authored

feat: add on policy distillation algorithm (#1006)

Signed-off-by: shuo_nvidia <shuoyang@nvidia.com>
Signed-off-by: alexchiu <qiuzhaopeng@foxmail.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com>
Signed-off-by: shuo-nvidia <shuoyang@nvidia.com>
Co-authored-by: shuo_nvidia <shuoyang@nvidia.com>
1 parent b445a3a commit 17ea9ab

File tree

36 files changed: +4490 -15 lines changed


README.md

Lines changed: 47 additions & 1 deletion
@@ -3,6 +3,8 @@
 [![CICD NeMo RL](https://github.com/NVIDIA-NeMo/RL/actions/workflows/cicd-main.yml/badge.svg?branch=main&event=schedule)](https://github.com/NVIDIA-NeMo/RL/actions/workflows/cicd-main.yml)

 ## 📣 News
+* [9/25/2025] On-policy Distillation (Qwen3-style)
+  * Student generates on-policy sequences and aligns logits to a larger teacher via KL, achieving near-larger-model quality at lower cost than RL. See [On-policy Distillation](#on-policy-distillation).
 * [7/25/2025] [Release v0.3.0!](https://github.com/NVIDIA-NeMo/RL/releases/tag/v0.3.0)
   * 📝 [v0.3.0 Blog Post](https://nvidia-nemo.github.io/blog/2025/07/21/nemo-rl-v0.3/)
   * 📊 View the release run metrics on [Google Colab](https://colab.research.google.com/drive/15kpesCV1m_C5UQFStssTEjaN2RsBMeZ0?usp=sharing) to get a head start on your experimentation.
@@ -59,7 +61,7 @@ For detailed information on backend selection, configuration, and examples, see
 - 🔜 **Megatron Bridge Integration** - Integrate Megatron Bridge to enable training features from Megatron Core.
 - 🔜 **NeMo Automodel Integration** - Integrate NeMo Automodel to power our DTensor path.
 - 🔜 **New Models** - gpt-oss.
-- 🔜 **Expand Algorithms** - DAPO, GSPO.
+- 🔜 **Expand Algorithms** - DAPO, GSPO, On-policy Distillation.
 - 🔜 **GB200** - Add container support for GB200.
 - **Distributed Training** - Ray-based infrastructure.
 - **Environment Support and Isolation** - Support for multi-environment training and dependency isolation between components.
@@ -83,6 +85,7 @@ For detailed information on backend selection, configuration, and examples, see
 |Algorithms|Single Node|Multi-node|
 |-|-|-|
 |[GRPO](#grpo)|[GRPO Single Node](#grpo-single-node)|[GRPO Multi-node](#grpo-multi-node): [GRPO Qwen2.5-32B](#grpo-qwen25-32b), [GRPO Multi-Turn](#grpo-multi-turn)|
+|[On-policy Distillation](#on-policy-distillation)|[Distillation Single Node](#on-policy-distillation-single-node)|[Distillation Multi-node](#on-policy-distillation-multi-node)|
 |[Supervised Fine-Tuning (SFT)](#supervised-fine-tuning-sft)|[SFT Single Node](#sft-single-node)|[SFT Multi-node](#sft-multi-node)|
 |[DPO](#dpo)|[DPO Single Node](#dpo-single-node)|[DPO Multi-node](#dpo-multi-node)|
 |[RM](#rm)|[RM Single Node](#rm-single-node)|[RM Multi-node](#rm-multi-node)|
@@ -312,6 +315,49 @@ Reference example for training to play a Sliding Puzzle Game:
 uv run python examples/run_grpo_sliding_puzzle.py
 ```

+## On-policy Distillation
+
+We provide an example on-policy distillation experiment using the [DeepScaler dataset](https://huggingface.co/agentica-org/DeepScaleR-1.5B-Preview).
+
+> [!NOTE]
+> Distillation currently supports the DTensor training backend and the vLLM generation backend. Megatron generation/training paths are not supported yet.
+
+### On-policy Distillation Single Node
+
+To run on-policy distillation on a single node using `Qwen/Qwen3-1.7B-Base` as the student and `Qwen/Qwen3-4B` as the teacher:
+
+```sh
+uv run python examples/run_distillation_math.py
+```
+
+Customize parameters with command-line overrides. For example:
+
+```sh
+uv run python examples/run_distillation_math.py \
+  policy.model_name="Qwen/Qwen3-1.7B-Base" \
+  teacher.model_name="Qwen/Qwen3-4B" \
+  cluster.gpus_per_node=8
+```
+
+### On-policy Distillation Multi-node
+
+```sh
+# Run from the root of NeMo RL repo
+NUM_ACTOR_NODES=2
+
+COMMAND="uv run ./examples/run_distillation_math.py --config examples/configs/distillation_math.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/distill_2nodes' logger.wandb_enabled=True logger.wandb.name='distill-2nodes'" \
+CONTAINER=YOUR_CONTAINER \
+MOUNTS="$PWD:$PWD" \
+sbatch \
+    --nodes=${NUM_ACTOR_NODES} \
+    --account=YOUR_ACCOUNT \
+    --job-name=YOUR_JOBNAME \
+    --partition=YOUR_PARTITION \
+    --time=4:0:0 \
+    --gres=gpu:8 \
+    ray.sub
+```
+
 ## Supervised Fine-Tuning (SFT)

 We provide example SFT experiments using various datasets including [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/), OpenAI format datasets (with tool calling support), and custom JSONL datasets. For detailed documentation on supported datasets and configurations, see the [SFT documentation](docs/guides/sft.md).
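---

The news entry and the new README section above describe the algorithm at a high level: the student generates its own rollouts, and its per-token logits are pulled toward a larger teacher through a KL objective computed over the teacher's top-k logits (see `loss_fn` and `topk_logits_k` in the config that follows). As a rough illustration only, here is a minimal PyTorch sketch of such a top-k forward/reverse/mixed KL loss. The function name, tensor shapes, and the lack of prompt/padding masking are assumptions for the sketch; it is not the implementation added by this commit.

```python
import torch
import torch.nn.functional as F


def topk_kl_distillation_loss(
    student_logits: torch.Tensor,   # [batch, seq_len, vocab]
    teacher_logits: torch.Tensor,   # [batch, seq_len, vocab]
    kl_type: str = "mixed",         # "forward", "reverse", or "mixed"
    mixed_kl_weight: float = 0.5,   # weight of the forward KL when kl_type == "mixed"
    topk_logits_k: int = 64,
) -> torch.Tensor:
    # Keep only the teacher's top-k vocabulary entries per position and gather
    # the matching student logits, so both distributions share the same support.
    topk_teacher, topk_idx = teacher_logits.topk(topk_logits_k, dim=-1)
    topk_student = student_logits.gather(-1, topk_idx)

    # Renormalizing inside the top-k is just one way to treat the tail of the
    # distribution (cf. the `zero_outside_topk` option in the config below).
    teacher_logp = F.log_softmax(topk_teacher, dim=-1)
    student_logp = F.log_softmax(topk_student, dim=-1)

    # Forward KL, KL(teacher || student): mode-covering.
    forward_kl = (teacher_logp.exp() * (teacher_logp - student_logp)).sum(dim=-1)
    # Reverse KL, KL(student || teacher): mode-seeking.
    reverse_kl = (student_logp.exp() * (student_logp - teacher_logp)).sum(dim=-1)

    if kl_type == "forward":
        kl = forward_kl
    elif kl_type == "reverse":
        kl = reverse_kl
    else:  # "mixed"
        kl = mixed_kl_weight * forward_kl + (1.0 - mixed_kl_weight) * reverse_kl

    # A real loss would mask prompt/padding positions and normalize by the
    # number of generated tokens; a plain mean keeps the sketch short.
    return kl.mean()
```

Restricting both distributions to the teacher's top-k entries bounds the amount of teacher logit data handled per token, which appears to be what `topk_logits_k: 64` and `zero_outside_topk` in the configuration control.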
examples/configs/distillation_math.yaml

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Distillation Algorithm Configuration
distillation:
  num_prompts_per_step: 128
  num_generations_per_prompt: 1
  max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question)
  max_num_steps: 1000
  val_batch_size: 64
  val_period: 20
  val_at_start: false
  max_val_samples: 512
  topk_logits_k: 64
  seed: 42

loss_fn:
  kl_type: "mixed" # forward, reverse, mixed
  mixed_kl_weight: 0.5 # when kl_type is "mixed", this is the weight of the forward KL
  zero_outside_topk: false # zero out the teacher logits outside the top-k when calculating the forward KL loss

checkpointing:
  enabled: true
  checkpoint_dir: "checkpoints/distillation-${policy.model_name}"
  metric_name: "val_reward"
  higher_is_better: true
  keep_top_k: 3
  save_period: 10
  checkpoint_must_save_by: null
  model_save_format: "safetensors"
  save_consolidated: false

policy: &POLICY_BASE
  model_name: "Qwen/Qwen3-1.7B-Base"
  tokenizer:
    name: ${..model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 64
  train_micro_batch_size: 1
  generation_batch_size: 64
  logprob_batch_size: 1
  max_total_sequence_length: 8192
  precision: "bfloat16"
  logprob_chunk_size: null

  dtensor_cfg: &DTENSOR_BASE
    enabled: true
    _v2: true
    cpu_offload: False
    sequence_parallel: false
    activation_checkpointing: true
    tensor_parallel_size: 2
    context_parallel_size: 2
    custom_parallel_plan: null

  dynamic_batching:
    enabled: true
    train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}}
    sequence_length_round: 64

  sequence_packing:
    enabled: false
    train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  max_grad_norm: 1.0
  # makes the training sequence length divisible by the tensor parallel size
  # this is useful for sequence parallel training
  # must be divisible by 2*cp
  make_sequence_length_divisible_by: ${mul:${mul:${.dtensor_cfg.tensor_parallel_size}, ${.dtensor_cfg.context_parallel_size}}, 2}
  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 2.0e-5
      weight_decay: 0.01
      betas: [0.9, 0.999]
      eps: 1e-8
      # when using Dtensor, we need to set foreach
      # and fused to False
      foreach: False
      fused: False

  megatron_cfg: # [TODO]
    enabled: false

  scheduler:
    - name: "torch.optim.lr_scheduler.LinearLR"
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        total_iters: 10
    - name: "torch.optim.lr_scheduler.ConstantLR"
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones: [10]

  generation:
    backend: "vllm"
    max_new_tokens: ${..max_total_sequence_length} # refer to local policy/teacher config
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: false
      precision: ${...precision}
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP
      gpu_memory_utilization: 0.6
      max_model_len: ${...max_total_sequence_length} # refer to local policy/teacher config
      enforce_eager: False
      use_deep_gemm: False
      num_last_layers_in_bf16: 0
      num_first_layers_in_bf16: 0
      distributed_executor_backend: null

    colocated:
      # true: generation shares training GPUs
      # false: uses dedicated generation resources
      enabled: true
      # only relevant when enabled is false
      resources:
        gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster, i.e. cluster.num_nodes == 1
        num_nodes: null # Decides number of nodes to be dedicated to generation


teacher:
  <<: *POLICY_BASE
  model_name: "Qwen/Qwen3-4B"
  dtensor_cfg:
    <<: *DTENSOR_BASE
    context_parallel_size: 2
    tensor_parallel_size: 4

data:
  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
  prompt_file: "examples/prompts/cot.txt"
  system_prompt_file: null
  dataset_name: "DeepScaler"
  shuffle: true

env:
  math:
    num_workers: 8

logger:
  log_dir: "logs/distillation"
  num_val_samples_to_print: 5
  wandb_enabled: true
  tensorboard_enabled: true
  mlflow_enabled: false
  swanlab_enabled: false
  monitor_gpus: true
  wandb:
    project: "nemo-distillation"
    name: "distillation-${data.dataset_name}-${teacher.model_name}-${policy.model_name}-${loss_fn.kl_type}-${distillation.topk_logits_k}"
  tensorboard:
    log_dir: "tb_logs-distillation-${data.dataset_name}"
  mlflow:
    experiment_name: "distillation-dev"
    run_name: "distillation-math-cl-logger"
  gpu_monitoring:
    collection_interval: 10
    flush_interval: 10

cluster:
  gpus_per_node: 8
  num_nodes: 1
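A note on the `${mul:...}` entries in the config above: they derive values from other fields, e.g. `train_mb_tokens = max_total_sequence_length × train_micro_batch_size = 8192 × 1 = 8192` and `make_sequence_length_divisible_by = TP × CP × 2 = 2 × 2 × 2 = 8`. The following is a self-contained sketch of how such expressions could resolve with an OmegaConf-style resolver; the resolver registration shown here is an assumption, not necessarily how NeMo RL wires it up.

```python
# Hypothetical sketch of resolving the ${mul:...} interpolations with OmegaConf.
from omegaconf import OmegaConf

# Assumption: a "mul" resolver is registered somewhere in the training code.
OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True)

cfg = OmegaConf.create(
    """
policy:
  max_total_sequence_length: 8192
  train_micro_batch_size: 1
  dtensor_cfg:
    tensor_parallel_size: 2
    context_parallel_size: 2
  dynamic_batching:
    train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}}
  make_sequence_length_divisible_by: ${mul:${mul:${.dtensor_cfg.tensor_parallel_size}, ${.dtensor_cfg.context_parallel_size}}, 2}
"""
)

# 8192 tokens/sequence * 1 sequence = 8192 tokens per dynamic-batching microbatch
print(cfg.policy.dynamic_batching.train_mb_tokens)   # 8192
# TP * CP * 2 = 2 * 2 * 2 = 8, matching the "must be divisible by 2*cp" comment
print(cfg.policy.make_sequence_length_divisible_by)  # 8
```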
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
defaults: ../../distillation_math.yaml
distillation:
  num_prompts_per_step: 64
  max_num_steps: 20
  val_batch_size: 32
  val_period: 10
  max_val_samples: 256
loss_fn:
  kl_type: reverse
checkpointing:
  checkpoint_dir: checkpoints/distillation-qwen3-32b-to-1.7b-base
policy:
  train_global_batch_size: 32
  generation_batch_size: 32
  dtensor_cfg:
    tensor_parallel_size: 1
    context_parallel_size: 1
  dynamic_batching:
    enabled: false
  make_sequence_length_divisible_by: 1
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        total_iters: 20
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones:
        - 20
teacher:
  model_name: Qwen/Qwen3-32B
  train_global_batch_size: 32
  generation_batch_size: 32
  dtensor_cfg:
    context_parallel_size: 1
  dynamic_batching:
    enabled: false
  make_sequence_length_divisible_by: 1
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        total_iters: 20
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones:
        - 20
logger:
  log_dir: logs/distillation-qwen3-32b-to-1.7b-base
  wandb:
    project: nemo-rl
    name: distillation-qwen3-32b-to-1.7b-base
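The `scheduler` lists in the base config and in the recipe above follow the same pattern: a `LinearLR` warmup, a `ConstantLR` tail, and a final `milestones` entry marking where to switch. One plausible reading is a `torch.optim.lr_scheduler.SequentialLR`-style composition; the sketch below shows that interpretation for the recipe's values (20 warmup steps, then constant). It is an assumption about how the list is consumed, not NeMo RL's actual scheduler-building code.

```python
# Hypothetical sketch: the recipe's scheduler list read as a SequentialLR chain.
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=2.0e-5, weight_decay=0.01,
    betas=(0.9, 0.999), eps=1e-8, foreach=False, fused=False,
)

# First list entry: linear warmup from 0.1x to 1.0x of the base LR over 20 steps.
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, end_factor=1.0, total_iters=20)
# Second entry: hold the LR constant afterwards.
constant = torch.optim.lr_scheduler.ConstantLR(
    optimizer, factor=1.0, total_iters=10_000_000_000)

# `milestones: [20]` -> switch from the warmup scheduler to the constant one
# after 20 optimizer steps.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, constant], milestones=[20])

for step in range(40):
    # ... forward/backward would go here in real training ...
    optimizer.step()
    scheduler.step()
```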
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
defaults: ../../distillation_math.yaml
distillation:
  num_prompts_per_step: 64
  max_num_steps: 20
  val_batch_size: 32
  val_period: 10
  max_val_samples: 256
loss_fn:
  kl_type: reverse
checkpointing:
  checkpoint_dir: checkpoints/distillation-qwen3-32b-to-4b-base-dynamicbatch
policy:
  model_name: Qwen/Qwen3-4B-Base
  train_global_batch_size: 32
  generation_batch_size: 32
  dtensor_cfg:
    context_parallel_size: 1
  make_sequence_length_divisible_by: 2
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        total_iters: 20
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones:
        - 20
teacher:
  model_name: Qwen/Qwen3-32B
  train_global_batch_size: 32
  generation_batch_size: 32
  dtensor_cfg:
    tensor_parallel_size: 8
    context_parallel_size: 1
  make_sequence_length_divisible_by: 2
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        total_iters: 20
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones:
        - 20
logger:
  log_dir: logs/distillation-qwen3-32b-to-4b-base-dynamicbatch
  wandb:
    project: nemo-rl
    name: distillation-qwen3-32b-to-4b-base-dynamicbatch
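Both recipes begin with `defaults: ../../distillation_math.yaml`, i.e. they are thin override layers on the base config above (reverse KL instead of mixed, a Qwen3-32B teacher, short 20-step runs). Below is a minimal sketch of how such a recipe could be layered onto its base with OmegaConf; the loader shown is an assumption about the mechanism, not NeMo RL's actual config plumbing. As a rough sizing note, a 32B-parameter teacher in bf16 holds about 32e9 × 2 bytes ≈ 64 GB of weights, so this recipe's `teacher.dtensor_cfg.tensor_parallel_size: 8` leaves roughly 8 GB of teacher weights per GPU on an 8-GPU node.

```python
# Hypothetical sketch of resolving a recipe's `defaults:` pointer against the
# base config; the real override mechanism in NeMo RL may differ.
from pathlib import Path
from omegaconf import OmegaConf


def load_recipe(recipe_path: str):
    recipe = OmegaConf.load(recipe_path)
    base_rel = recipe.pop("defaults", None)  # e.g. "../../distillation_math.yaml"
    if base_rel is None:
        return recipe
    base = OmegaConf.load(Path(recipe_path).parent / str(base_rel))
    # Recipe keys win over the base: loss_fn.kl_type -> "reverse",
    # teacher.model_name -> "Qwen/Qwen3-32B", teacher TP -> 8, and list-valued
    # keys such as `scheduler` are replaced wholesale rather than merged.
    return OmegaConf.merge(base, recipe)
```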
