
Commit 67458df

[megatron] Support dpo lora (#4913)
1 parent 68a6f80 commit 67458df

File tree

9 files changed (+101 / -43 lines)

docs/source/BestPractices/Qwen3最佳实践.md

Lines changed: 2 additions & 2 deletions
@@ -339,7 +339,7 @@ megatron sft \
     --load Qwen3-30B-A3B-Base-mcore \
     --dataset 'liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT' \
     --split_dataset_ratio 0.01 \
-    --tensor_model_parallel_size 2 \
+    --pipeline_model_parallel_size 2 \
     --expert_model_parallel_size 8 \
     --moe_grouped_gemm true \
     --moe_shared_expert_overlap true \
@@ -366,7 +366,7 @@ megatron sft \
     --no_save_optim true \
     --no_save_rng true \
     --sequence_parallel true \
-    --use_flash_attn true
+    --attention_backend flash
 ```

 Training loss curve (partial):

docs/source/Instruction/Megatron-SWIFT训练.md

Lines changed: 7 additions & 6 deletions
@@ -44,14 +44,14 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2

 First, we need to convert the weights from HF format to Megatron format:
 - If you encounter OOM, simply remove `CUDA_VISIBLE_DEVICES=0`.
-- For "ms-swift>=3.6", it is recommended to add the `--test_convert_precision true` parameter to test conversion precision.
 ```shell
 CUDA_VISIBLE_DEVICES=0 \
 swift export \
     --model Qwen/Qwen2.5-7B-Instruct \
     --to_mcore true \
     --torch_dtype bfloat16 \
-    --output_dir Qwen2.5-7B-Instruct-mcore
+    --output_dir Qwen2.5-7B-Instruct-mcore \
+    --test_convert_precision true
 ```

 Next, use the following script to start training. The required GPU memory is 2*80GiB:
@@ -93,14 +93,14 @@ megatron sft \
 Finally, convert the Megatron-format weights back to HF format:
 - Note: point `--mcore_model` to the parent directory of `iter_xxx`. By default, the checkpoint recorded in `latest_checkpointed_iteration.txt` is used.
 - If you encounter OOM, simply remove `CUDA_VISIBLE_DEVICES=0`.
-- For "ms-swift>=3.6", it is recommended to add the `--test_convert_precision true` parameter to test conversion precision.
 ```shell
 CUDA_VISIBLE_DEVICES=0 \
 swift export \
     --mcore_model megatron_output/Qwen2.5-7B-Instruct/vx-xxx \
     --to_hf true \
     --torch_dtype bfloat16 \
-    --output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf
+    --output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf \
+    --test_convert_precision true
 ```

 We then perform inference on the generated HF-format weights:
@@ -172,10 +172,10 @@ MCore-to-HF conversion script:
 ```bash
 CUDA_VISIBLE_DEVICES=0 \
 swift export \
-    --mcore_adapters /mnt/nas2/huangjintao.hjt/work/llmscope/megatron_output/Qwen3-30B-A3B/v5-20250710-204630 \
+    --mcore_adapters megatron_output/Qwen2.5-7B-Instruct/vx-xxx \
     --to_hf true \
     --torch_dtype bfloat16 \
-    --output_dir /mnt/nas2/huangjintao.hjt/work/llmscope/megatron_output/Qwen3-30B-A3B/v5-20250710-204630-hf \
+    --output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf \
     --test_convert_precision true
 ```
 - Note: the `mcore_adapters` folder contains an `args.json` file. During conversion, the `mcore_model` and LoRA-related parameters are read from this file, `mcore_model` and `mcore_adapters` are merged (merge-lora) into full weights, and the result is converted to HF-format weights.
@@ -402,6 +402,7 @@ LoRA training:
 - adapter_load: Path of the adapter weights to load. Default is `None`.
 - 🔥target_modules: Suffixes of the modules to apply LoRA to. Default is `['all-linear']`.
 - 🔥target_regex: Regex expression specifying the LoRA modules. Default is `None`. If provided, the `target_modules` parameter is ignored.
+- 🔥modules_to_save: After the tuner is attached, additionally specifies original model modules to participate in training and saving. Default is `[]`.
 - 🔥lora_rank: Default is `8`.
 - 🔥lora_alpha: Default is `32`.
 - lora_dropout: Default is `0.05`.

docs/source_en/BestPractices/Qwen3-Best-Practice.md

Lines changed: 2 additions & 2 deletions
@@ -343,7 +343,7 @@ megatron sft \
     --load Qwen3-30B-A3B-Base-mcore \
     --dataset 'liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT' \
     --split_dataset_ratio 0.01 \
-    --tensor_model_parallel_size 2 \
+    --pipeline_model_parallel_size 2 \
     --expert_model_parallel_size 8 \
     --moe_grouped_gemm true \
     --moe_shared_expert_overlap true \
@@ -370,7 +370,7 @@ megatron sft \
     --no_save_optim true \
     --no_save_rng true \
     --sequence_parallel true \
-    --use_flash_attn true
+    --attention_backend flash
 ```

docs/source_en/Instruction/Megatron-SWIFT-Training.md

Lines changed: 5 additions & 4 deletions
@@ -45,14 +45,14 @@ This section introduces a quick start example for fine-tuning the self-awareness

 First, we need to convert the weights from HF (Hugging Face) format to Megatron format:
 - If you encounter OOM, simply remove `CUDA_VISIBLE_DEVICES=0`.
-- For "ms-swift>=3.6", it is recommended to add the `--test_convert_precision true` parameter to test conversion precision.
 ```shell
 CUDA_VISIBLE_DEVICES=0 \
 swift export \
     --model Qwen/Qwen2.5-7B-Instruct \
     --to_mcore true \
     --torch_dtype bfloat16 \
-    --output_dir Qwen2.5-7B-Instruct-mcore
+    --output_dir Qwen2.5-7B-Instruct-mcore \
+    --test_convert_precision true
 ```

 Next, use the following script to start training. The required GPU memory resources are 2*80GiB:
@@ -94,15 +94,15 @@ megatron sft \
 Finally, convert the Megatron format weights back to HF format:
 - Note: Please point `--mcore_model` to the parent directory of `iter_xxx`. By default, the corresponding checkpoint from `latest_checkpointed_iteration.txt` will be used.
 - If you encounter OOM, simply remove `CUDA_VISIBLE_DEVICES=0`.
-- For "ms-swift>=3.6", it is recommended to add the `--test_convert_precision true` parameter to test conversion precision.

 ```shell
 CUDA_VISIBLE_DEVICES=0 \
 swift export \
     --mcore_model megatron_output/Qwen2.5-7B-Instruct/vx-xxx \
     --to_hf true \
     --torch_dtype bfloat16 \
-    --output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf
+    --output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf \
+    --test_convert_precision true
 ```

 We then perform inference on the generated HF format weights:
@@ -423,6 +423,7 @@ LoRA Training:
 - adapter_load: Path to the adapter weights to be loaded. Default is `None`.
 - 🔥target_modules: Suffixes of modules to apply LoRA to. Default is `['all-linear']`.
 - 🔥target_regex: Regex expression to specify LoRA modules. Default is `None`. If this value is provided, the `target_modules` parameter will be ignored.
+- 🔥modules_to_save: After attaching a tuner, explicitly specifies additional original model modules to participate in training and storage. The default is `[]`.
 - 🔥lora_rank: Default is `8`.
 - 🔥lora_alpha: Default is `32`.
 - lora_dropout: Default is `0.05`.

examples/train/megatron/lora/dpo.sh

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# 2 * 55GiB; 4.50s/it
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+NPROC_PER_NODE=2 \
+CUDA_VISIBLE_DEVICES=0,1 \
+megatron rlhf \
+    --rlhf_type dpo \
+    --load Qwen3-30B-A3B-Base-mcore \
+    --dataset 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#20000' \
+    --train_type lora \
+    --lora_rank 8 \
+    --lora_alpha 32 \
+    --target_modules all-linear \
+    --split_dataset_ratio 0.01 \
+    --expert_model_parallel_size 2 \
+    --moe_grouped_gemm true \
+    --moe_shared_expert_overlap true \
+    --moe_aux_loss_coeff 0.01 \
+    --micro_batch_size 8 \
+    --global_batch_size 16 \
+    --recompute_granularity full \
+    --recompute_method uniform \
+    --recompute_num_layers 1 \
+    --max_epochs 1 \
+    --finetune true \
+    --cross_entropy_loss_fusion true \
+    --lr 1e-4 \
+    --lr_warmup_fraction 0.05 \
+    --min_lr 1e-5 \
+    --save megatron_output/Qwen3-30B-A3B-Base \
+    --eval_interval 100 \
+    --save_interval 100 \
+    --max_length 8192 \
+    --num_workers 8 \
+    --dataset_num_proc 8 \
+    --no_save_optim true \
+    --no_save_rng true \
+    --sequence_parallel true \
+    --attention_backend flash \
+    --beta 0.1 \
+    --rpo_alpha 1 \
+    --loss_type sigmoid
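
As documented in the MCore-to-HF conversion section of this commit, the LoRA checkpoint saved by this script can afterwards be merged with the base weights and exported to HF format. A minimal sketch, where `vx-xxx` stands in for the actual checkpoint directory created under `--save`:

```shell
# Hedged sketch: merge-lora the saved adapter and convert to HF format.
# vx-xxx is a placeholder for the real checkpoint folder name.
CUDA_VISIBLE_DEVICES=0 \
swift export \
    --mcore_adapters megatron_output/Qwen3-30B-A3B-Base/vx-xxx \
    --to_hf true \
    --torch_dtype bfloat16 \
    --output_dir megatron_output/Qwen3-30B-A3B-Base/vx-xxx-hf \
    --test_convert_precision true
```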

examples/train/megatron/moe/qwen3_moe.sh

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ megatron sft \
     --load Qwen3-30B-A3B-Base-mcore \
     --dataset 'liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT' \
     --split_dataset_ratio 0.01 \
-    --tensor_model_parallel_size 2 \
+    --pipeline_model_parallel_size 2 \
     --expert_model_parallel_size 8 \
     --moe_grouped_gemm true \
     --moe_shared_expert_overlap true \

swift/megatron/trainers/base.py

Lines changed: 3 additions & 3 deletions
@@ -157,9 +157,9 @@ def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
     def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs):

         def new_model_provider_func(*args, **kwargs):
-            model = model_provider_func(*args, **kwargs)
-            prepare_mcore_model(model)
-            return model
+            self.unwrapped_model = model_provider_func(*args, **kwargs)
+            self.peft_model = prepare_mcore_model(self.unwrapped_model)
+            return self.unwrapped_model

         with self._patch_load_state_dict():
             model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer(

swift/megatron/trainers/dpo_trainer.py

Lines changed: 27 additions & 12 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from collections import namedtuple
+from contextlib import contextmanager, nullcontext
 from functools import partial

 import torch
@@ -40,13 +41,16 @@ def __init__(self, args):

     def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs):
         args = get_args()
-        ref_model = get_model(model_provider_func, model_type)
-        if args.ref_load is None:
-            args.ref_load = args.load
-        args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
-            ref_model, None, None, load_arg='ref_load')
-        self.ref_model = ref_model[0]
-        self.ref_model.eval()
+        if args.train_type == 'full':
+            ref_model = get_model(model_provider_func, model_type)
+            if args.ref_load is None:
+                args.ref_load = args.load
+            args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
+                ref_model, None, None, load_arg='ref_load')
+            self.ref_model = ref_model[0]
+            self.ref_model.eval()
+        else:
+            self.ref_model = None
         return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs)

     @staticmethod
@@ -78,8 +82,7 @@ def _forward_step_helper(model, inputs):

         return output_tensor

-    def ref_forward(self, data_iterator):
-        ref_model = unwrap_model(self.ref_model)
+    def ref_forward(self, ref_model, data_iterator):
         with self.stimer(bdata=True):
             data = get_batch(data_iterator)
             data.pop('loss_scale', None)
@@ -144,13 +147,25 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab
         loss = loss / mpu.get_context_parallel_world_size()
         return (loss, reporting_metric)

+    @contextmanager
+    def null_ref_context(self):
+        args = get_args()
+        if args.train_type == 'full':
+            context = nullcontext()
+            ref_model = unwrap_model(self.ref_model)
+        else:
+            context = self.peft_model.disable_adapter()
+            ref_model = self.unwrapped_model
+        with context:
+            yield ref_model
+
     def _replace_data_iterator(self, data_iterator):
         args = get_args()
         num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size())
         res = []
-        for i in range(num_iters_per_step):
-            with torch.no_grad():
-                res.append(self.ref_forward(data_iterator))
+        with torch.no_grad(), self.null_ref_context() as ref_model:
+            for i in range(num_iters_per_step):
+                res.append(self.ref_forward(ref_model, data_iterator))
         return iter(res)

     def forward_step(self, data_iterator, model):

swift/megatron/tuners/lora.py

Lines changed: 13 additions & 13 deletions
@@ -225,21 +225,21 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights):

     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
         previous_dtype = x.dtype
-        if self.disable_adapters:
-            if self.merged:
-                self.unmerge()
-            result, bias = self.base_layer(x, *args, **kwargs)
-        elif self.merged:
-            result, bias = self.base_layer(x, *args, **kwargs)
-        else:
-            if isinstance(self.base_layer, TELayerNormColumnParallelLinear):
-                self.base_layer.return_layernorm_output = True
-                result, bias = self.base_layer(x, *args, **kwargs)
-                result, x = result  # ln_out
-            elif isinstance(self.base_layer, (TELinear, TEGroupedLinear)):
+        if self.disable_adapters and self.merged:
+            self.unmerge()
+
+        if isinstance(self.base_layer, TELayerNormColumnParallelLinear):
+            if self.disable_adapters or self.merged:
+                self.base_layer.return_layernorm_output = False
                 result, bias = self.base_layer(x, *args, **kwargs)
             else:
-                raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}')
+                self.base_layer.return_layernorm_output = True
+                (result, x), bias = self.base_layer(x, *args, **kwargs)
+        elif isinstance(self.base_layer, (TELinear, TEGroupedLinear)):
+            result, bias = self.base_layer(x, *args, **kwargs)
+        else:
+            raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}')
+        if not self.disable_adapters and not self.merged:
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
