Commit ce426e1

[sft] support DFT (#5355)
* lint
* update readme & optimize target_probs compute
* fix scripts
* position argument
* compatible with sp
1 parent ca0abf7 commit ce426e1

11 files changed: +51 -9 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ You can contact us and communicate with us by adding our group:


 ## 🎉 News
+- 🎁 2025.08.12: Support [Dynamic Fine-Tuning](https://arxiv.org/abs/2508.05629)(DFT) in SFT training, use parameter `--enable_dft_loss true`. Training scripts can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/dft.sh).
 - 🎁 2025.07.12: Deployment(pt/vLLM/SGLang) of Embedding models is supported, check [here](examples/deploy/embedding/client.py).
 - 🎁 2025.07.09: Megatron-SWIFT supports LoRA training. Compared to ms-swift, it achieves significant speedup on MoE models. Training scripts can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/megatron/lora).
 - 🎁 2025.06.23: Fine-tuning of reranker models is supported. Training scripts can be found here: [Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh).

README_CN.md

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@
 - **Model quantization**: Supports quantized export with AWQ, GPTQ, FP8, and BNB; exported models can be served with vLLM/SGLang/LmDeploy for accelerated inference and can also continue training.

 ## 🎉 News
+- 🎁 2025.08.12: Support [Dynamic Fine-Tuning](https://arxiv.org/abs/2508.05629) (DFT) in SFT training via the parameter `--enable_dft_loss true`. For the training script, see [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/dft.sh).
 - 🎁 2025.07.12: Deployment of Embedding models (pt/vLLM/SGLang) is supported; see [here](examples/deploy/embedding/client.py).
 - 🎁 2025.07.09: Megatron-SWIFT supports LoRA training. Compared with ms-swift, it is significantly faster on MoE models. For training scripts, see [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/megatron/lora).
 - 🎁 2025.06.23: Training of Reranker models is supported. For the training script, see [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh).

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions
@@ -165,6 +165,7 @@
 - logging_steps: Logging interval; defaults to 5.
 - router_aux_loss_coef: Sets the weight of aux_loss when training MoE models; defaults to `0.`
   - Note: In "ms-swift==3.7.0" the default is None and the value is read from config.json; this behavior was changed in "ms-swift>=3.7.1".
+- enable_dft_loss: Whether to use [DFT](https://arxiv.org/abs/2508.05629) (Dynamic Fine-Tuning) loss in SFT training; defaults to False.
 - logging_dir: Path for TensorBoard logs. Defaults to None, which means it is set to `f'{self.output_dir}/runs'`.
 - predict_with_generate: Whether to use generation during validation; defaults to False.
 - metric_for_best_model: Defaults to None, which means it is set to 'loss' when `predict_with_generate` is False and to 'rouge-l' otherwise (no default is set for PPO training; for GRPO training it is set to 'reward').

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
@@ -168,6 +168,7 @@ This parameter list inherits from transformers `Seq2SeqTrainingArguments`, with
 - logging_steps: Interval for logging, defaults to 5.
 - router_aux_loss_coef: Sets the weight of the aux_loss when training MoE models; default is `0.`
   - Note: In ms-swift == 3.7.0, the default is None and the value is read from config.json; this behavior was changed starting with ms-swift >= 3.7.1.
+- enable_dft_loss: Whether to use [DFT](https://arxiv.org/abs/2508.05629) (Dynamic Fine-Tuning) loss in SFT training, default is False.
 - logging_dir: The path for TensorBoard logs. Defaults to None, which means it is set to `f'{self.output_dir}/runs'`.
 - predict_with_generate: Whether to use generative method during validation, default is False.
 - metric_for_best_model: Default is None, which means that when predict_with_generate is set to False, it is set to 'loss'; otherwise, it is set to 'rouge-l' (during PPO training, the default value is not set; in GRPO training, it is set to 'reward').

examples/train/full/dft.sh

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# 4*80G
+# exp: https://github.com/modelscope/ms-swift/pull/5355
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+NPROC_PER_NODE=4 \
+swift sft \
+    --model Qwen/Qwen2.5-Math-1.5B \
+    --train_type full \
+    --dataset AI-MO/NuminaMath-CoT#100000 \
+    --torch_dtype bfloat16 \
+    --enable_dft_loss true \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 8 \
+    --learning_rate 5e-5 \
+    --gradient_accumulation_steps 32 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --system 'You are a helpful assistant.' \
+    --warmup_ratio 0.1 \
+    --deepspeed zero2 \
+    --dataloader_num_workers 4
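
As a rough sanity check on the script above, the global batch size follows from the launch settings. The sketch below assumes plain data parallelism across all four processes and that all 100000 sampled examples end up in the training split; neither assumption is stated in the commit, the variable names are illustrative, and the exact step count reported by the trainer may differ slightly.

```python
import math

# Values copied from examples/train/full/dft.sh
nproc_per_node = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 32
num_samples = 100_000  # from the "#100000" suffix on the dataset name

# Sequences consumed per optimizer step (assumes pure data parallelism).
global_batch_size = (
    nproc_per_node * per_device_train_batch_size * gradient_accumulation_steps
)

# Approximate optimizer steps for one epoch (assumes no held-out split).
optimizer_steps = math.ceil(num_samples / global_batch_size)

print(global_batch_size, optimizer_steps)
```

Under these assumptions each optimizer step covers 1024 sequences, so one epoch is on the order of a hundred updates.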

swift/plugin/loss.py

Lines changed: 5 additions & 1 deletion
@@ -12,7 +12,7 @@
 from transformers.utils import strtobool


-def per_token_loss_func(outputs, labels, **kwargs):
+def per_token_loss_func(outputs, labels, enable_dft_loss, **kwargs):
     logits = outputs.logits
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
@@ -23,6 +23,10 @@ def per_token_loss_func(outputs, labels, **kwargs):
     # Enable model parallelism
     labels = labels.to(logits.device)
     loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction='none')
+    if enable_dft_loss:
+        with torch.no_grad():
+            target_probs = torch.exp(-loss)
+        loss *= target_probs
     return loss
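
The added branch relies on the identity that `exp(-cross_entropy)` is the model's probability for the target token. The following is a minimal pure-Python sketch of the forward computation only, not the commit's code: the helper names are hypothetical, and because plain Python has no autograd it cannot show the `torch.no_grad()` part, which in the diff keeps gradients from flowing through the weight.

```python
import math

IGNORE_INDEX = -100  # label value that cross-entropy skips


def per_token_cross_entropy(logits, label):
    """Cross-entropy for one token; 0.0 for ignored labels."""
    if label == IGNORE_INDEX:
        return 0.0
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[label]


def dft_per_token_loss(logits, label):
    """Scale cross-entropy by the probability assigned to the target token."""
    ce = per_token_cross_entropy(logits, label)
    target_prob = math.exp(-ce)  # equals softmax(logits)[label]
    return ce * target_prob


# Two equal logits give the target a probability of 0.5.
uniform = dft_per_token_loss([0.0, 0.0], 0)
# Ignored positions have zero loss, so the weight exp(0) = 1 leaves them at zero.
ignored = dft_per_token_loss([1.0, 2.0], IGNORE_INDEX)
```

One consequence worth noting is that positions masked with `-100` stay at zero after reweighting, so the masking behaviour of the original loss is preserved.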

swift/trainers/arguments.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ class TrainArgumentsMixin:
     logging_first_step: bool = True
     logging_steps: int = 5
     router_aux_loss_coef: float = 0.
+    enable_dft_loss: bool = False  # https://arxiv.org/abs/2508.05629

     weight_decay: float = 0.1
     adam_beta2: float = 0.95

swift/trainers/sequence_parallel/ring_attention.py

Lines changed: 2 additions & 1 deletion
@@ -216,9 +216,10 @@ def prepare_trainer(self, trainer):
         trainer.ring_attention = self

         if trainer.__class__.__name__ == 'Seq2SeqTrainer':
+            enable_dft_loss = trainer.args.enable_dft_loss
             trainer._origin_prepare_inputs = trainer._prepare_inputs
             trainer._prepare_inputs = MethodType(partial(_prepare_inputs, sp_instance=self), trainer)
-            trainer.compute_loss_func = partial(loss_scale_sp_func, sp_instance=self)
+            trainer.compute_loss_func = partial(loss_scale_sp_func, sp_instance=self, enable_dft_loss=enable_dft_loss)

         elif trainer.__class__.__name__ == 'DPOTrainer':
             trainer._origin_prepare_inputs = trainer._prepare_inputs

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 2 additions & 1 deletion
@@ -326,9 +326,10 @@ def prepare_trainer(self, trainer):

         trainer.ulysses = self
         if trainer.__class__.__name__ == 'Seq2SeqTrainer':
+            enable_dft_loss = trainer.args.enable_dft_loss
             trainer._origin_prepare_inputs = trainer._prepare_inputs
             trainer._prepare_inputs = MethodType(partial(_prepare_inputs, sp_instance=self), trainer)
-            trainer.compute_loss_func = partial(loss_scale_sp_func, sp_instance=self)
+            trainer.compute_loss_func = partial(loss_scale_sp_func, sp_instance=self, enable_dft_loss=enable_dft_loss)

         elif trainer.__class__.__name__ == 'DPOTrainer':
             trainer._origin_prepare_inputs = trainer._prepare_inputs
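
Both sequence-parallel backends pass the flag the same way: it is read once from `trainer.args` and bound into the loss callable with `functools.partial`. A small sketch of that binding pattern, using a hypothetical stand-in function rather than the real `loss_scale_sp_func`:

```python
from functools import partial


def toy_loss_func(outputs, labels, sp_instance=None, enable_dft_loss=False, **kwargs):
    """Stand-in that only reports which options it received."""
    return {'sp_instance': sp_instance, 'enable_dft_loss': enable_dft_loss}


# Bind the options once, in the same way prepare_trainer does in the diffs above.
compute_loss_func = partial(toy_loss_func, sp_instance='sp', enable_dft_loss=True)

# A later caller supplies only per-batch arguments; extras land in **kwargs.
result = compute_loss_func(outputs=None, labels=None, num_items_in_batch=4)
```

Because the value is captured when `prepare_trainer` runs, changing `trainer.args.enable_dft_loss` afterwards would not affect an already-bound callable in this pattern.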

swift/trainers/sequence_parallel/utils.py

Lines changed: 5 additions & 0 deletions
@@ -125,6 +125,7 @@ def loss_scale_sp_func(outputs,
                        loss_scale=None,
                        num_items_in_batch=None,
                        sp_instance=None,
+                       enable_dft_loss=False,
                        **kwargs) -> torch.Tensor:
     """Common loss function for sequence parallel training"""
     if hasattr(outputs, 'logits'):
@@ -146,6 +147,10 @@ def loss_scale_sp_func(outputs,
     else:
         loss_fct = CrossEntropyLoss(reduction='none')
         loss = loss_fct(logits, labels)
+    if enable_dft_loss:
+        with torch.no_grad():
+            target_probs = torch.exp(-loss)
+        loss *= target_probs
     if loss_scale is not None:
         loss_scale = loss_scale.flatten().to(device)
         loss = (loss_scale * loss)
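
Both loss functions touched by this commit apply the same per-token reweighting. Writing it out, and letting $\operatorname{sg}[\cdot]$ denote stop-gradient (the role played by `torch.no_grad()` in the diff), shows what the gradient becomes. The notation is introduced here for illustration and is not taken from the commit.

```latex
\begin{align}
\ell_t &= -\log p_\theta(y_t \mid x, y_{<t}) \\
w_t &= \operatorname{sg}\!\left[e^{-\ell_t}\right]
     = \operatorname{sg}\!\left[p_\theta(y_t \mid x, y_{<t})\right] \\
\ell_t^{\mathrm{DFT}} &= w_t \, \ell_t \\
\nabla_\theta \ell_t^{\mathrm{DFT}}
  &= w_t \, \nabla_\theta \ell_t
   = -\,p_\theta(y_t \mid x, y_{<t}) \, \nabla_\theta \log p_\theta(y_t \mid x, y_{<t})
   = -\,\nabla_\theta \, p_\theta(y_t \mid x, y_{<t})
\end{align}
```

Under this reading, tokens the model already finds unlikely contribute a smaller gradient than they would under plain cross-entropy. Ignored positions have $\ell_t = 0$ and $w_t = 1$, so they remain zero.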
