
Commit 45fde49

[feat] support megatron gkd (#7216)
1 parent b6eb9e8 commit 45fde49

18 files changed

Lines changed: 1572 additions & 329 deletions


docs/source/Instruction/GKD.md

Lines changed: 4 additions & 3 deletions
@@ -103,7 +103,7 @@ elif seq_kd:
     y = teacher.generate(x)
     source = "teacher"
 else:
-    # Mode 3: Off-policy learning, use the output sequence from the dataset
+    # Mode 3: use the output sequence from the dataset
     y = y_ground_truth
     source = "dataset"

@@ -128,7 +128,7 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
 
 **Data source**: $y \sim P_{\text{teacher}}(\cdot | x)$
 
-### Mode 3: Off-policy learning (other cases)
+### Mode 3: Offline learning (other cases)
 
 **Data source**: $y = y^* \sim \text{Dataset}$

@@ -143,9 +143,10 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
 |------|------|--------|---------|------|
 | `--teacher_model` | str | Required | - | Teacher model path or model ID |
 | `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL<br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
-| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-policy learning trigger probability<br>• 0.0: Pure off-policy<br>• 0.5: Mixed strategy<br>• 1.0: Pure on-policy |
+| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-policy learning trigger probability<br>• 0.0: Offline learning<br>• 0.5: Mixed strategy<br>• 1.0: Pure on-policy |
 | `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: use the dataset when not on-policy<br>• True: use teacher generation when not on-policy |
 | `--temperature` | float | 0.9 | > 0 | Sampling temperature for generation; controls randomness |
+| `--sft_alpha` | float | 0 | >= 0 | Mixes in a proportion of SFT loss; applies to completions not generated by the student |
 | `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |
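Editor's note: the `--beta` interpolation in the table follows the generalized JSD of the GKD paper. The sketch below is illustrative pure Python over small discrete distributions, not ms-swift's implementation (which works on logits); the endpoint special-casing to forward/reverse KL mirrors common implementations, and all function names here are my own.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generalized_jsd(p_teacher, p_student, beta):
    """beta=0 -> forward KL(teacher||student); beta=1 -> reverse KL(student||teacher).
    Interior values interpolate through the mixture M = beta*teacher + (1-beta)*student
    (GKD-paper convention); beta=0.5 gives the symmetric JSD."""
    if beta == 0.0:
        return kl(p_teacher, p_student)
    if beta == 1.0:
        return kl(p_student, p_teacher)
    m = [beta * t + (1 - beta) * s for t, s in zip(p_teacher, p_student)]
    return beta * kl(p_teacher, m) + (1 - beta) * kl(p_student, m)

teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]
loss = generalized_jsd(teacher, student, beta=0.5)  # symmetric JSD term
```

At `beta=0.5` the result is symmetric in teacher and student, which is why the docs call it "balanced".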
 
 ## Sampling Acceleration

docs/source/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 16 additions & 1 deletion
@@ -326,7 +326,7 @@ Megatron training parameters inherit from Megatron parameters and basic parameters (**shared with ms-swift
 
 ## RLHF Parameters
 In addition to inheriting the training parameters, the following parameters are also supported:
-- 🔥rlhf_type: Defaults to 'dpo'. Currently 'dpo', 'grpo', 'kto', and 'rm' are available.
+- 🔥rlhf_type: Defaults to 'dpo'. Currently 'dpo', 'grpo', 'kto', 'rm', and 'gkd' are available.
 - loss_scale: Overrides the loss_scale in [basic parameters](../Instruction/Command-line-parameters.md). Defaults to 'last_round'.
 - calculate_per_token_loss: Overrides the Megatron parameter. Defaults to False.

@@ -406,6 +406,21 @@ Megatron training parameters inherit from Megatron parameters and basic parameters (**shared with ms-swift
 
 Built-in reward function parameters: see the [documentation](../Instruction/Command-line-parameters.md#奖励函数参数)
 
+### GKD Parameters
+- teacher_model: Path or model ID of the teacher model. Required.
+- teacher_model_type: Teacher model type. Defaults to None (auto-detected).
+- teacher_model_revision: Teacher model revision. Defaults to None.
+- beta: JSD divergence interpolation coefficient. 0.0 means forward KL, 0.5 means symmetric JSD, 1.0 means reverse KL. Defaults to 0.5.
+- lmbda: On-policy learning trigger probability. 0.0 means pure off-policy, 1.0 means pure on-policy. Defaults to 0.5.
+- seq_kd: Whether to use teacher-generated responses (Sequential KD); currently not supported. Defaults to False.
+- temperature: Temperature used for sampling and loss computation. Defaults to 0.9.
+- offload_teacher_model: Whether to offload the teacher model to CPU to save GPU memory. Defaults to False.
+- sft_alpha: Mixing coefficient for the SFT loss, `loss = jsd_loss + sft_alpha * sft_loss`. Takes effect when dataset responses are used (off-policy). Defaults to 0.
+- max_completion_length: Maximum number of tokens during generation. Defaults to 512.
+- vllm_mode: Same as the GRPO parameter; used for on-policy generation. In colocate mode, vLLM is deployed inside the training process.
+- Note: on-policy generation requires vLLM (`--use_vllm true --vllm_mode colocate/server`).
+- When `lmbda > 0` but vLLM is not enabled, training automatically falls back to off-policy mode.
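Editor's note: the `sft_alpha` combination `loss = jsd_loss + sft_alpha * sft_loss`, applied only to responses the student did not generate, can be sketched as follows (hypothetical helper with scalar losses, not ms-swift's actual code):

```python
def gkd_total_loss(jsd_loss, sft_loss, sft_alpha, source):
    """Combine the distillation (JSD) loss with an optional SFT term.

    Per the parameter description, sft_alpha only takes effect when the
    response was NOT generated by the student (e.g. source == "dataset").
    """
    if source == "student":
        return jsd_loss
    return jsd_loss + sft_alpha * sft_loss

# Off-policy sample with sft_alpha = 0.2: 1.5 + 0.2 * 2.0 = 1.9
total = gkd_total_loss(1.5, 2.0, sft_alpha=0.2, source="dataset")
```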
 ## Export Parameters
 This section describes the parameters of `megatron export` (requires "ms-swift>=3.10"). To export with the `swift export` command, see the [ms-swift command-line parameters documentation](../Instruction/Command-line-parameters.md#导出参数). Compared with `swift export`, `megatron export` supports distributed and multi-node export. Megatron export parameters inherit from Megatron parameters and basic parameters.
 - 🔥to_mcore: Convert HF-format weights to Megatron format. Defaults to False.

docs/source/Megatron-SWIFT/GKD.md

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# GKD

**Version requirement**: ms-swift >= 3.12

If you are new to GKD, first read the [GKD documentation](../Instruction/GKD.md).

GKD (Generalized Knowledge Distillation) is a training method that transfers a teacher model's knowledge to a student model; distillation is realized by computing a Jensen-Shannon divergence (JSD) loss between the two models' output distributions.

## Supported Features

Megatron GKD currently supports the following features:

- **Training modes**: full-parameter training and LoRA fine-tuning
- **Parallelism strategies**: context parallelism (CP), pipeline parallelism (PP), tensor parallelism (TP), and expert parallelism (EP)
- **Model support**: compatible with the LLMs and MLLMs in Megatron-SWIFT
- **Teacher offload**: the teacher model can be offloaded to CPU to save GPU memory
- **Online generation**: on-policy generation for the student model via vLLM

### Current Limitations

- **Teacher online generation** (`seq_kd=True`): teacher generation in Sequential KD mode is not yet supported
- **Non-vLLM generation**: on-policy generation currently supports vLLM only
- **Teacher parallel settings different from the student's**: will be supported in a future version

⚠️ Notes:
- **On-policy generation**: requires vLLM (`--use_vllm true --vllm_mode colocate/server`)
- When `lmbda > 0` but vLLM is not enabled, training automatically falls back to offline learning (dataset responses)
- When `seq_kd=True`, since teacher generation is not yet supported, training automatically falls back to offline learning; if you need it, run inference over the dataset in advance with [swift infer](../Instruction/Inference-and-deployment.md)

## Parameters

### GKD-specific Parameters

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--teacher_model` | str | Required | Teacher model path or model ID |
| `--beta` | float | 0.5 | JSD interpolation coefficient:<br>• 0.0: Forward KL<br>• 0.5: Symmetric JSD<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | On-policy learning trigger probability:<br>• 0.0: Pure off-policy<br>• 1.0: Pure on-policy |
| `--seq_kd` | bool | False | Whether to use teacher-generated responses (not yet supported) |
| `--temperature` | float | 0.9 | Temperature used for sampling and loss computation |
| `--sft_alpha` | float | 0 | Mixes in a proportion of SFT loss; applies to completions not generated by the student |
| `--max_completion_length` | int | 512 | Maximum number of tokens during generation |

### Batch-related Parameters

As in Megatron SFT, use the following parameters to control batch size:

| Parameter | Description |
|------|------|
| `--micro_batch_size` | Per-GPU training batch size |
| `--global_batch_size` | Global batch size: `micro_batch_size × dp_size × gradient_accumulation_steps` |

## Three Training Modes

GKD supports three training modes, controlled by the `lmbda` and `seq_kd` parameters:

### Mode 1: On-policy learning
- Trigger: `random() < lmbda` and `use_vllm=True`
- Data source: responses generated by the student model

### Mode 2: Sequential KD (not yet supported)
- Trigger: `random() >= lmbda` and `seq_kd=True`
- Data source: responses generated by the teacher model

### Mode 3: Off-policy learning
- Trigger: all other cases
- Data source: annotated responses from the dataset
## Reference

For more parameters, see the [command-line documentation](./Command-line-parameters.md).

For training scripts, see the [Megatron GKD scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd).

docs/source/Megatron-SWIFT/Quick-start.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ ms-swift introduces Megatron's parallelism techniques to accelerate large-model training, including data
 | Pre-training ||||||
 | [Supervised Fine-Tuning](https://github.com/modelscope/ms-swift/tree/main/examples/megatron) ||||||
 | [GRPO](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/grpo) ||||||
+| [GKD](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/gkd) ||||||
 | [DPO](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/dpo) ||||||
 | [KTO](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/kto) ||||||
 | [RM](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/rm) ||||||

docs/source_en/Instruction/GKD.md

Lines changed: 5 additions & 4 deletions
@@ -1,6 +1,6 @@
 # GKD
 
-GKD (Generalized Knowledge Distillation) training algorithm is proposed in the paper [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://arxiv.org/pdf/2306.13649). This algorithm transfers knowledge from the teacher model to the student model by combining off-policy and on-policy learning strategies.
+The GKD (Generalized Knowledge Distillation) training algorithm was proposed in the paper [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://arxiv.org/pdf/2306.13649). It transfers knowledge from the teacher model to the student model by combining offline and on-policy learning strategies.
 
 ## Loss Function
 
@@ -103,7 +103,7 @@ elif seq_kd:
     y = teacher.generate(x)
     source = "teacher"
 else:
-    # Mode 3: Off-Policy learning, use output sequence from dataset
+    # Mode 3: Offline learning, use output sequence from dataset
     y = y_ground_truth
     source = "dataset"
 
@@ -128,7 +128,7 @@ Set parameter `seq_kd=True`, when on-policy is not triggered, use teacher model
 
 **Data Source**: $y \sim P_{\text{teacher}}(\cdot | x)$
 
-### Mode 3: Off-Policy Learning (other cases)
+### Mode 3: Offline Learning (other cases)
 
 **Data Source**: $y = y^* \sim \text{Dataset}$
 
@@ -143,9 +143,10 @@ We can perform GKD training by setting the following parameters:
 |------|------|--------|---------|------|
 | `--teacher_model` | str | Required | - | Teacher model path or model ID |
 | `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL<br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
-| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Off-Policy<br>• 0.5: Mixed strategy (**recommended**)<br>• 1.0: Pure On-Policy |
+| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Offline<br>• 0.5: Mixed strategy (**recommended**)<br>• 1.0: Pure On-Policy |
 | `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: Use dataset when not on-policy<br>• True: Use teacher generation when not on-policy |
 | `--temperature` | float | 0.9 | > 0 | Generation sampling temperature, controls randomness |
+| `--sft_alpha` | float | 0 | >= 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
 | `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |
 
 ## Sampling Acceleration

docs/source_en/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 17 additions & 1 deletion
@@ -347,7 +347,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa
 
 In addition to inheriting the training parameters, the following parameters are also supported:
 
-- 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo', 'grpo', 'kto', and 'rm' are available.
+- 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo', 'grpo', 'kto', 'rm', and 'gkd' are available.
 - loss_scale: Overrides the `loss_scale` in [basic parameters](../Instruction/Command-line-parameters.md). Default is 'last_round'.
 - calculate_per_token_loss: Overrides the Megatron parameter. Default is False.
 
@@ -430,6 +430,22 @@ In addition to inheriting the training parameters, the following parameters are
 
 Built-in reward function parameters refer to the [documentation](../Instruction/Command-line-parameters.md#reward-function-parameters).
 
+### GKD Parameters
+
+- teacher_model: Path or model ID of the teacher model. Required.
+- teacher_model_type: Teacher model type. Default is None, auto-detected.
+- teacher_model_revision: Teacher model revision. Default is None.
+- beta: JSD divergence interpolation coefficient. 0.0 means Forward KL, 0.5 means symmetric JSD, 1.0 means Reverse KL. Default is 0.5.
+- lmbda: On-Policy learning probability. 0.0 means pure Off-Policy, 1.0 means pure On-Policy. Default is 0.5.
+- seq_kd: Whether to use teacher-generated responses (Sequential KD), not yet supported. Default is False.
+- temperature: Temperature for sampling and loss computation. Default is 0.9.
+- offload_teacher_model: Whether to offload the teacher model to CPU to save GPU memory. Default is False.
+- sft_alpha: Mixing coefficient for SFT loss, `loss = jsd_loss + sft_alpha * sft_loss`. Takes effect when using dataset responses (Off-Policy). Default is 0.
+- max_completion_length: Maximum tokens for generation. Default is 512.
+- vllm_mode: Same as the GRPO parameter, used for On-Policy generation. Colocate mode deploys vLLM within the program.
+- Note: On-Policy generation requires vLLM (`--use_vllm true --vllm_mode colocate/server`).
+- When `lmbda > 0` but vLLM is not enabled, it will automatically fall back to Off-Policy mode.
 
 ## Export Parameters
 
 This section introduces the parameters for `megatron export` (requires "ms-swift>=3.10"). To use the `swift export` command for exporting, please refer to the [ms-swift Command Line Parameters Documentation](../Instruction/Command-line-parameters.md#export-arguments). Compared to `swift export`, `megatron export` supports distributed and multi-node exporting. Megatron export parameters inherit from Megatron parameters and basic parameters.

docs/source_en/Megatron-SWIFT/GKD.md
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# GKD

**Version Requirement**: ms-swift >= 3.12

If you are new to GKD, please refer to the [GKD Documentation](../Instruction/GKD.md) first.

GKD (Generalized Knowledge Distillation) is a training method that transfers knowledge from a teacher model to a student model by computing the Jensen-Shannon Divergence (JSD) loss between their output distributions.

## Feature Support

Megatron GKD currently supports the following features:

- **Training Modes**: Full-parameter training and LoRA fine-tuning
- **Parallelism Strategies**: Context Parallel (CP), Pipeline Parallel (PP), Tensor Parallel (TP), and Expert Parallel (EP)
- **Model Support**: Compatible with the LLMs and MLLMs in Megatron-SWIFT
- **Teacher Offload**: Supports offloading the teacher model to CPU to save GPU memory
- **Online Generation**: Supports on-policy generation with vLLM for the student model

### Current Limitations

- **Teacher Model Online Generation** (`seq_kd=True`): Teacher model generation in Sequential KD mode is not yet supported
- **Non-vLLM Generation**: On-policy generation currently only supports vLLM
- **Teacher model with parallel settings different from the student's**: Will be supported in future versions

⚠️ Notes:

- **On-policy Generation**: Requires vLLM (`--use_vllm true --vllm_mode colocate/server`)
- When `lmbda > 0` but vLLM is not enabled, training automatically falls back to off-policy mode (using dataset responses)
- When `seq_kd=True`, since teacher generation is not yet supported, training automatically falls back to off-policy mode. If you need it, use [swift infer](../Instruction/Inference-and-deployment.md) to pre-generate responses for the dataset

## Parameters

### GKD-specific Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--teacher_model` | str | Required | Path or model ID of the teacher model |
| `--beta` | float | 0.5 | JSD divergence interpolation coefficient:<br>• 0.0: Forward KL<br>• 0.5: Symmetric JSD<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | On-policy learning probability:<br>• 0.0: Pure off-policy<br>• 1.0: Pure on-policy |
| `--seq_kd` | bool | False | Use teacher-generated responses (not yet supported) |
| `--temperature` | float | 0.9 | Temperature for sampling and loss computation |
| `--sft_alpha` | float | 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| `--max_completion_length` | int | 512 | Maximum tokens for generation |

### Batch-related Parameters

As in Megatron SFT, use the following parameters to control batch size:

| Parameter | Description |
|-----------|-------------|
| `--micro_batch_size` | Training batch size per GPU |
| `--global_batch_size` | Global batch size: `micro_batch_size × dp_size × gradient_accumulation_steps` |
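Editor's note: as a worked check of the global-batch-size formula, using the values from this commit's 8-GPU example script and assuming the usual Megatron relation `dp_size = world_size / (TP × PP × CP)`:

```python
# Values from the 8-GPU example script in this commit
world_size = 8
tp, pp, cp = 2, 2, 2                      # tensor / pipeline / context parallel
micro_batch_size, global_batch_size = 2, 16

dp_size = world_size // (tp * pp * cp)    # data-parallel size: 8 / 8 = 1
grad_accum = global_batch_size // (micro_batch_size * dp_size)  # 16 / 2 = 8
```

So each optimizer step accumulates 8 micro-batches of 2 samples on the single data-parallel replica.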
## Three Training Modes

GKD supports three training modes, controlled by the `lmbda` and `seq_kd` parameters:

### Mode 1: On-Policy Learning
- Trigger: `random() < lmbda` and `use_vllm=True`
- Data source: Responses generated by the student model

### Mode 2: Sequential KD (Not Yet Supported)
- Trigger: `random() >= lmbda` and `seq_kd=True`
- Data source: Responses generated by the teacher model

### Mode 3: Off-Policy Learning
- Trigger: Other cases
- Data source: Labeled responses from the dataset

## Reference

For more parameters, please refer to the [Command-line Parameters](./Command-line-parameters.md).

For training scripts, please refer to the [Megatron GKD scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd).

docs/source_en/Megatron-SWIFT/Quick-start.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ ms-swift incorporates Megatron's parallelization techniques to accelerate the tr
 | Pre-training ||||||
 | [Supervised Fine-Tuning](https://github.com/modelscope/ms-swift/tree/main/examples/megatron) ||||||
 | [GRPO](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/grpo) ||||||
+| [GKD](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/gkd) ||||||
 | [DPO](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/dpo) ||||||
 | [KTO](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/kto) ||||||
 | [RM](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/rm) ||||||
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
megatron rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen3-8B-Base \
    --teacher_model Qwen/Qwen3-32B \
    --train_type lora \
    --dataset AI-ModelScope/alpaca-gpt4-data-en#2000 AI-ModelScope/alpaca-gpt4-data-zh#2000 \
    --tensor_model_parallel_size 2 \
    --expert_model_parallel_size 1 \
    --pipeline_model_parallel_size 2 \
    --context_parallel_size 2 \
    --seq_kd false \
    --lmbda 1 \
    --beta 1 \
    --torch_dtype bfloat16 \
    --micro_batch_size 2 \
    --global_batch_size 16 \
    --max_epochs 1 \
    --lr 5e-6 \
    --log_interval 1 \
    --max_length 8192 \
    --max_completion_length 8192 \
    --attention_backend flash \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.5 \
    --vllm_tensor_parallel_size 1 \
    --vllm_max_model_len 16384 \
    --sleep_level 1 \
    --offload_teacher_model true \
    --recompute_granularity selective \
    --finetune \
    --no_save_optim \
    --no_save_rng \
    --temperature 1.0 \
    --padding_free true \
    --sequence_parallel true
