Skip to content

Commit e060ad8

Browse files
authored
[dpo] support dpo padding_free/logits_to_keep & dpo compat trl==0.18 (#4394)
1 parent 152f3a6 commit e060ad8

File tree

13 files changed

+195
-94
lines changed

13 files changed

+195
-94
lines changed

docs/source/Instruction/命令行参数.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
- 🔥agent_template: Agent模板,确定如何将工具列表转换成system,如何从模型回复中提取toolcall,以及确定`{"role": "tool_call", "content": "xxx"}`, `{"role": "tool_response", "content": "xxx"}`的模板格式。可选为"react_en", "hermes", "glm4", "qwen_en", "toolbench"等,更多请查看[这里](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/agent_template/__init__.py)。默认为None,根据模型类型进行选择。
7979
- norm_bbox: 控制如何缩放边界框(bbox)。选项为'norm1000'和'none'。'norm1000'表示将bbox坐标缩放至千分之一,而'none'表示不进行缩放。默认值为None,将根据模型自动选择。
8080
- use_chat_template: 使用chat模板或generation模板,默认为`True``swift pt`会自动设置为generation模板。
81-
- 🔥padding_free: 将一个batch中的数据进行展平而避免数据padding,从而降低显存占用并加快训练。默认为False。当前支持`swift pt/sft`
81+
- 🔥padding_free: 将一个batch中的数据进行展平而避免数据padding,从而降低显存占用并加快训练。默认为False。当前支持CPT/SFT/DPO/GRPO
8282
- 注意:使用padding_free请结合`--attn_impl flash_attn`使用且"transformers>=4.44",具体查看[该PR](https://github.com/huggingface/transformers/pull/31629)。(同packing)
8383
- 支持的多模态模型与多模态packing支持情况相同。相较于packing,padding_free不额外消耗时间和空间。
8484
- Megatron-SWIFT默认使用padding_free,即`qkv_format='thd'`,不需要额外设置。
@@ -88,7 +88,7 @@
8888
- 'all': 计算所有tokens的损失。
8989
- 'ignore_empty_think': 在`'default'`的基础上,忽略空的`'<think>\n\n</think>\n\n'`损失计算,具体请参考[此issue](https://github.com/modelscope/ms-swift/issues/4030)
9090
- 'react', 'hermes', 'qwen': 在`'default'`的基础上,将`tool_call`部分的loss权重调整为2。
91-
- sequence_parallel_size: 序列并行大小,默认是1。当前支持pt/sft/dpo。训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text/sequence_parallel.sh)
91+
- sequence_parallel_size: 序列并行大小,默认是1。当前支持CPT/SFT/DPO/GRPO。训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text/sequence_parallel.sh)
9292
- response_prefix: response的前缀字符,例如QwQ-32B将response_prefix设置为`'<think>\n'`。默认为None,根据模型自动设置。
9393
- 注意:若对deepseek-r1/qwq模型使用不包含`<think>...</think>`的数据集进行训练,请加在推理训练后模型时额外传入`--response_prefix ''`
9494
- template_backend: 选择template后端,可选为'swift'、'jinja',默认为'swift'。如果使用jinja,则使用transformers的`apply_chat_template`

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Hints:
7979
- 🔥agent_template: Agent template, which determines how to convert the list of tools into a system, how to extract tool calls from the model's response, and specifies the template format for `{"role": "tool_call", "content": "xxx"}` and `{"role": "tool_response", "content": "xxx"}`. Optional values include "react_en", "hermes", "glm4", "qwen_en", "toolbench", etc. For more details, please check [here](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/agent_template/__init__.py). The default value is None, meaning it will be selected based on the model type.
8080
- norm_bbox: Controls how to scale bounding boxes (bbox). Options are 'norm1000' and 'none'. 'norm1000' represents scaling bbox coordinates to one-thousandths, and 'none' means no scaling. Default is None, automatically selected based on the model.
8181
- use_chat_template: Use chat template or generation template, default is `True`. `swift pt` is automatically set to the generation template.
82-
- 🔥padding_free: Flattens the data in a batch to avoid padding, thereby reducing memory usage and accelerating training. Default is False. Currently supports `swift pt/sft`.
82+
- 🔥padding_free: Flattens the data in a batch to avoid padding, thereby reducing memory usage and accelerating training. Default is False. Currently supported in CPT/SFT/DPO/GRPO.
8383
- Note: When using `padding_free`, it should be combined with `--attn_impl flash_attn` and "transformers>=4.44". For details, see [this PR](https://github.com/huggingface/transformers/pull/31629). (Same as packing)
8484
- The supported multimodal models are the same as those supported for multimodal packing. Compared to packing, padding_free does not consume additional time or space.
8585
- Megatron-SWIFT uses `padding_free` by default, i.e., `qkv_format='thd'`, and no additional configuration is required.
@@ -89,7 +89,7 @@ Hints:
8989
- 'all': Calculate the loss for all tokens.
9090
- 'ignore_empty_think': On top of 'default', ignore the loss calculation for empty `'<think>\n\n</think>\n\n'`. See [this issue](https://github.com/modelscope/ms-swift/issues/4030) for more details.
9191
- `'react'`, `'hermes'`, `'qwen'`: On top of `'default'`, set the loss weight of the `tool_call` part to 2.
92-
- sequence_parallel_size: Sequence parallelism size, default is 1. Currently supported in pt/sft/dpo. The training script refers to [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text/sequence_parallel.sh).
92+
- sequence_parallel_size: Sequence parallelism size, default is 1. Currently supported in CPT/SFT/DPO/GRPO. The training script refers to [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text/sequence_parallel.sh).
9393
- response_prefix: The prefix character for the response, for example, setting the response_prefix to `'<think>\n'` for QwQ-32B. The default is None, and it is automatically set according to the model.
9494
- Note: If you are training the deepseek-r1/qwq model with a dataset that does not include `<think>...</think>`, please pass `--response_prefix ''` additionally when inferring after training.
9595
- template_backend: Selection of the template backend. Options are 'swift' and 'jinja', with 'swift' as the default. If using jinja, it applies transformer's `apply_chat_template`.

examples/train/padding_free/dpo.sh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# with padding_free: 4 * 47GiB, 1.90s/it
2+
# without padding_free: 4 * 57GiB 3.32s/it
3+
NPROC_PER_NODE=4 \
4+
CUDA_VISIBLE_DEVICES=0,1,2,3 \
5+
swift rlhf \
6+
--rlhf_type dpo \
7+
--model Qwen/Qwen2.5-7B-Instruct \
8+
--train_type full \
9+
--dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
10+
--torch_dtype bfloat16 \
11+
--num_train_epochs 1 \
12+
--per_device_train_batch_size 4 \
13+
--per_device_eval_batch_size 4 \
14+
--learning_rate 1e-5 \
15+
--gradient_accumulation_steps 1 \
16+
--eval_steps 100 \
17+
--save_steps 100 \
18+
--save_total_limit 2 \
19+
--logging_steps 5 \
20+
--max_length 8192 \
21+
--output_dir output \
22+
--warmup_ratio 0.05 \
23+
--save_only_model true \
24+
--dataloader_num_workers 4 \
25+
--dataset_num_proc 4 \
26+
--deepspeed zero3 \
27+
--attn_impl flash_attn \
28+
--save_only_model true \
29+
--padding_free true
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# with padding_free: 4 * 53GiB, 3.55s/it
2+
# without padding_free: 4 * 62GiB 4.41s/it
3+
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
4+
CUDA_VISIBLE_DEVICES=0,1,2,3 \
5+
NPROC_PER_NODE=4 \
6+
MAX_PIXELS=1003520 \
7+
swift rlhf \
8+
--rlhf_type dpo \
9+
--model Qwen/Qwen2.5-VL-7B-Instruct \
10+
--dataset 'swift/RLAIF-V-Dataset#20000' \
11+
--train_type full \
12+
--torch_dtype bfloat16 \
13+
--num_train_epochs 1 \
14+
--per_device_train_batch_size 4 \
15+
--per_device_eval_batch_size 4 \
16+
--learning_rate 1e-5 \
17+
--freeze_vit true \
18+
--gradient_accumulation_steps 1 \
19+
--eval_steps 100 \
20+
--save_steps 100 \
21+
--save_total_limit 2 \
22+
--deepspeed zero3 \
23+
--logging_steps 5 \
24+
--max_length 4096 \
25+
--output_dir output \
26+
--warmup_ratio 0.05 \
27+
--dataloader_num_workers 4 \
28+
--dataset_num_proc 4 \
29+
--attn_impl flash_attn \
30+
--save_only_model true \
31+
--padding_free true

examples/train/rlhf/dpo/lora.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# 24GiB
2+
# It is recommended to use padding_free. For more details, please refer to:
3+
# https://github.com/modelscope/ms-swift/blob/main/examples/train/padding_free/dpo.sh
24
CUDA_VISIBLE_DEVICES=0 \
35
swift rlhf \
46
--rlhf_type dpo \

swift/llm/argument/infer_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def _init_ddp(self):
158158
if not is_dist():
159159
return
160160
assert not self.eval_human and not self.stream, (
161+
'In DDP scenarios, interactive interfaces and streaming output are not supported.'
161162
f'args.eval_human: {self.eval_human}, args.stream: {self.stream}')
162163
self._init_device()
163164
init_process_group(backend=self.ddp_backend, timeout=self.ddp_timeout)

swift/llm/train/sft.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,10 @@ def _get_data_collator(self):
8181
padding_to = args.max_length if args.train_type == 'longlora' else None
8282
return partial(template.data_collator, padding_to=padding_to)
8383

84-
@staticmethod
85-
def _save_val_dataset(output_dir: str, val_dataset):
86-
if is_master() and isinstance(val_dataset, HfDataset):
84+
def _save_val_dataset(self, val_dataset):
85+
args = self.args
86+
output_dir = getattr(args, 'output_dir', None) or getattr(args, 'save')
87+
if is_master() and isinstance(val_dataset, HfDataset) and not args.val_dataset:
8788
os.makedirs(output_dir, exist_ok=True)
8889
val_dataset_path = os.path.join(output_dir, 'val_dataset.jsonl')
8990
append_to_jsonl(val_dataset_path, val_dataset.to_list())
@@ -216,8 +217,7 @@ def _stat_dataset(self, dataset: Union[HfDataset, PackingDataset]):
216217
def _encode_dataset(self, train_dataset, val_dataset):
217218
template = self.template
218219
args = self.args
219-
output_dir = getattr(args, 'output_dir', None) or getattr(args, 'save')
220-
self._save_val_dataset(output_dir, val_dataset)
220+
self._save_val_dataset(val_dataset)
221221
is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
222222
predict_with_generate = getattr(args, 'predict_with_generate', False)
223223
if not is_grpo:

swift/trainers/arguments.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,26 @@ def _new_checkpoint(*args, use_reentrant=None, **kwargs):
8787
except (ImportError, AttributeError):
8888
pass
8989

90+
@staticmethod
91+
def _patch_liger_kernel():
92+
# fix logits_to_keep
93+
from liger_kernel.transformers.model import loss_utils
94+
origin_LigerForCausalLMLoss = loss_utils.LigerForCausalLMLoss
95+
96+
def LigerForCausalLMLoss(hidden_states, *args, **kwargs):
97+
hidden_states = hidden_states.contiguous()
98+
return origin_LigerForCausalLMLoss(hidden_states, *args, **kwargs)
99+
100+
loss_utils.LigerForCausalLMLoss = LigerForCausalLMLoss
101+
logger.info('Patch liger_kernel successfully.')
102+
90103
def _init_liger(self):
91104
if self.use_liger_kernel:
92105
assert is_liger_available(), 'use_liger_kernel requires liger_kernels, try `pip install liger-kernel`'
106+
try:
107+
self._patch_liger_kernel()
108+
except Exception:
109+
pass
93110

94111
def __post_init__(self):
95112
if is_mp() and self.use_liger_kernel:

swift/trainers/mixin.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template
3535
from swift.plugin import MeanMetric, compute_acc, extra_tuners
3636
from swift.tuners import SwiftModel
37-
from swift.utils import get_logger, is_mp_ddp, ms_logger_context, seed_worker, use_torchacc
37+
from swift.utils import get_logger, is_mp, is_mp_ddp, ms_logger_context, seed_worker, use_torchacc
3838
from swift.utils.torchacc_utils import ta_trim_graph
3939
from ..utils.torch_utils import get_device_count
4040
from .arguments import TrainingArguments
@@ -484,6 +484,36 @@ def _evalscope_eval(self):
484484
self.model.train()
485485
return eval_dict
486486

487+
def get_logits_to_keep(self, labels):
488+
if labels.shape[0] == 1 and not is_mp():
489+
# device_map may encounter device mismatch issues.
490+
loss_mask = (labels != -100)[0]
491+
labels = labels[:, loss_mask]
492+
labels = nn.functional.pad(labels, (1, 0), value=-100)
493+
logits_to_keep = nn.functional.pad(loss_mask[1:], (0, 1), value=True)
494+
else:
495+
logits_to_keep = labels.shape[-1] - ((labels != -100).int().argmax(-1).min().item()) + 1
496+
assert logits_to_keep > 0
497+
labels = labels[:, -logits_to_keep:]
498+
return labels, logits_to_keep
499+
500+
def get_cu_seqlens(self, position_ids, logits_to_keep) -> torch.Tensor:
501+
assert position_ids.shape[0] == 1
502+
position_ids = position_ids[0]
503+
indices = torch.arange(position_ids.shape[0], device=position_ids.device)
504+
cu_seqlens = torch.concat([
505+
indices[position_ids == 0],
506+
torch.tensor(position_ids.shape, device=position_ids.device),
507+
])
508+
res_cu_seqlens = cu_seqlens.clone()
509+
if isinstance(logits_to_keep, torch.Tensor):
510+
for i in range(cu_seqlens.shape[0] - 1):
511+
start, end = cu_seqlens[i], cu_seqlens[i + 1]
512+
res_cu_seqlens[i + 1:] -= (~logits_to_keep[start:end]).sum()
513+
elif isinstance(logits_to_keep, int):
514+
res_cu_seqlens[1:] -= position_ids.shape[0] + 1 - logits_to_keep
515+
return res_cu_seqlens
516+
487517
def get_batch_samples(self, *args, **kwargs):
488518
res = super().get_batch_samples(*args, **kwargs)
489519
from swift.trainers.sequence_parallel import sequence_parallel

0 commit comments

Comments
 (0)