
Commit f9afe2d

Support ring attention for llm sft/dpo/grpo (packing/padding_free only). (#4814)

1 parent 4f93387 commit f9afe2d
16 files changed: +1126 −621 lines changed
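The commit title is the only description, so a minimal single-process sketch of the ring-attention idea may help: the sequence is split into blocks, each rank keeps its own query block, and key/value blocks travel around a ring so every rank eventually attends over the whole sequence without ever holding it locally. This is illustrative only, not ms-swift's code: it is non-causal, and it concatenates scores where real implementations accumulate with an online softmax and overlap the ring communication with compute.

import torch

def ring_attention_sketch(q, k, v, world_size):
    # q, k, v: (seq_len, head_dim), seq_len divisible by world_size
    q_blocks, k_blocks, v_blocks = q.chunk(world_size), k.chunk(world_size), v.chunk(world_size)
    scale = q.shape[-1] ** -0.5
    outputs = []
    for rank in range(world_size):                 # each "rank" owns one query block
        ring = [(rank + step) % world_size for step in range(world_size)]  # kv arrival order
        scores = torch.cat([q_blocks[rank] @ k_blocks[s].T * scale for s in ring], dim=-1)
        probs = torch.softmax(scores, dim=-1)      # softmax over the full sequence
        values = torch.cat([v_blocks[s] for s in ring], dim=0)
        outputs.append(probs @ values)
    return torch.cat(outputs, dim=0)

# Sanity check against ordinary full attention.
q, k, v = torch.randn(8, 16), torch.randn(8, 16), torch.randn(8, 16)
full = torch.softmax(q @ k.T * q.shape[-1] ** -0.5, dim=-1) @ v
assert torch.allclose(ring_attention_sketch(q, k, v, world_size=4), full, atol=1e-5)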

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@
 - 'all': Compute the loss over all tokens.
 - 'ignore_empty_think': On top of `'default'`, skip loss computation for empty `'<think>\n\n</think>\n\n'` spans; see [this issue](https://github.com/modelscope/ms-swift/issues/4030) for details.
 - 'react', 'hermes', 'qwen': On top of `'default'`, raise the loss weight of the `tool_call` part to 2.
-- sequence_parallel_size: Sequence parallelism size, default is 1. Currently supports CPT/SFT/DPO/GRPO. For a training script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text/sequence_parallel.sh)
+- sequence_parallel_size: Sequence parallelism size, default is 1. Currently supports CPT/SFT/DPO/GRPO. For a training script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text/ulysses/sequence_parallel.sh)
 - response_prefix: Prefix string for the response, e.g. QwQ-32B sets response_prefix to `'<think>\n'`. Defaults to None and is set automatically according to the model.
   - Note: If you train the deepseek-r1/qwq models on a dataset that does not include `<think>...</think>`, additionally pass `--response_prefix ''` when running inference on the trained model.
 - template_backend: Template backend to use, either 'swift' or 'jinja'; 'swift' is the default. With jinja, transformers' `apply_chat_template` is used.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ Hints:
 - 'all': Calculate the loss for all tokens.
 - 'ignore_empty_think': On top of 'default', ignore the loss calculation for empty `'<think>\n\n</think>\n\n'`. See [this issue](https://github.com/modelscope/ms-swift/issues/4030) for more details.
 - `'react'`, `'hermes'`, `'qwen'`: On top of `'default'`, set the loss weight of the `tool_call` part to 2.
-- sequence_parallel_size: Sequence parallelism size, default is 1. Currently supported in CPT/SFT/DPO/GRPO. For a training script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text/sequence_parallel.sh).
+- sequence_parallel_size: Sequence parallelism size, default is 1. Currently supported in CPT/SFT/DPO/GRPO. For a training script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text/ulysses/sequence_parallel.sh).
 - response_prefix: The prefix string for the response; for example, QwQ-32B sets response_prefix to `'<think>\n'`. The default is None, and it is set automatically according to the model.
   - Note: If you train the deepseek-r1/qwq model on a dataset that does not include `<think>...</think>`, additionally pass `--response_prefix ''` when running inference on the trained model.
 - template_backend: Selection of the template backend. Options are 'swift' and 'jinja', with 'swift' as the default. If using jinja, transformers' `apply_chat_template` is applied.
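To make sequence_parallel_size concrete, below is a minimal sketch (not ms-swift internals) of the simplest sharding scheme: each of the N sequence-parallel ranks processes a contiguous 1/N slice of every sequence, which is why max_length can grow roughly N-fold at a similar per-GPU activation footprint. Real implementations may instead use interleaved or zigzag layouts to balance causal-attention work across ranks.

import torch

def shard_sequence(input_ids: torch.Tensor, rank: int, sp_size: int) -> torch.Tensor:
    # input_ids: (batch, seq_len); seq_len must be divisible by sp_size
    assert input_ids.shape[1] % sp_size == 0, "pad or pack to a multiple of sp_size"
    return input_ids.chunk(sp_size, dim=1)[rank]   # this rank's contiguous slice

ids = torch.arange(16).reshape(1, 16)
print(shard_sequence(ids, rank=1, sp_size=4))      # tensor([[4, 5, 6, 7]])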
Lines changed: 31 additions & 0 deletions (new file)

# Env: 4 * A100
# Max Length: 65536
# GPU Memory: 4 * 38GiB, Training Speed 30s/it
NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
SEQUENCE_PARALLEL_IMPL=ring_attention \
RING_HEAD_STRIDE=2 \
swift sft \
    --model Qwen/Qwen2.5-3B-Instruct \
    --train_type full \
    --dataset 'AI-ModelScope/LongAlpaca-12k' \
    --torch_dtype bfloat16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 8 \
    --packing true \
    --rope_scaling yarn \
    --max_length 65536 \
    --eval_steps 50 \
    --save_steps 50 \
    --logging_steps 5 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 8 \
    --dataset_num_proc 8 \
    --save_total_limit 2 \
    --save_only_model true \
    --output_dir output/Qwen2.5-3B-Instruct \
    --deepspeed zero3 \
    --attn_impl flash_attn \
    --sequence_parallel_size 4
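Per the commit title, ring attention is supported for packing/padding_free only, and this script enables --packing true. A hedged sketch of what packing does conceptually, with greedy first-fit shown for illustration (ms-swift's actual packer may differ): short samples are concatenated into fixed-size bins so every training row reaches max_length and shards evenly across the ring.

def pack_samples(lengths, max_length):
    """Greedy first-fit packing of sample lengths into max_length bins."""
    bins, free = [], []                      # sample indices per bin, room left per bin
    for idx, n in enumerate(lengths):
        for b in range(len(bins)):
            if n <= free[b]:                 # first bin with enough room
                bins[b].append(idx)
                free[b] -= n
                break
        else:                                # no bin fits: open a new one
            bins.append([idx])
            free.append(max_length - n)
    return bins

print(pack_samples([30000, 20000, 40000, 10000], max_length=65536))
# [[0, 1, 3], [2]] -- three short samples share one 65536-token row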
Lines changed: 30 additions & 0 deletions (new file)

# Env: 4 * A100
# Max Length: 256000
# GPU Memory: 4 * 42GiB, Training Speed 43s/it
NPROC_PER_NODE=4 \
CELOSS_PARALLEL_SIZE=2048 \
SEQUENCE_PARALLEL_IMPL=ring_attention \
swift sft \
    --model Qwen/Qwen2.5-7B-Instruct \
    --train_type lora \
    --dataset 'AI-ModelScope/LongAlpaca-12k' \
    --torch_dtype bfloat16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 2 \
    --packing true \
    --rope_scaling yarn \
    --max_length 256000 \
    --eval_steps 200 \
    --save_steps 200 \
    --logging_steps 5 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 8 \
    --dataset_num_proc 8 \
    --save_total_limit 2 \
    --use_liger_kernel true \
    --save_only_model true \
    --deepspeed zero3_offload \
    --attn_impl flash_attn \
    --sequence_parallel_size 4
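CELOSS_PARALLEL_SIZE=2048 plausibly sets a chunk size for the cross-entropy computation; that is an assumption, since the variable's exact semantics are not documented in this diff. The sketch below shows the general chunked-loss technique such a knob would control: computing logits 2048 tokens at a time avoids materializing a (256000, vocab_size) logit matrix all at once.

import torch
import torch.nn.functional as F

def chunked_ce_loss(hidden, lm_head_weight, labels, chunk=2048):
    # hidden: (num_tokens, hidden_dim); labels: (num_tokens,), -100 = ignore
    total, count = hidden.new_zeros(()), 0
    for i in range(0, hidden.shape[0], chunk):
        logits = hidden[i:i + chunk] @ lm_head_weight.T   # only a small logit slice lives at once
        tgt = labels[i:i + chunk]
        mask = tgt != -100
        if mask.any():
            total = total + F.cross_entropy(logits[mask], tgt[mask], reduction="sum")
            count += int(mask.sum())
    return total / max(count, 1)

# Tiny demo: 10 tokens, hidden size 8, vocab size 32, chunks of 4 tokens.
print(chunked_ce_loss(torch.randn(10, 8), torch.randn(32, 8), torch.randint(0, 32, (10,)), chunk=4))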
Lines changed: 29 additions & 0 deletions (new file)

# Env: 4 * A100
# GPU Memory: 4 * 52GiB, Training Speed 4s/it
NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift rlhf \
    --rlhf_type dpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --train_type full \
    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 4 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --max_length 8192 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --save_only_model true \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --deepspeed zero3 \
    --attn_impl flash_attn \
    --padding_free true \
    --sequence_parallel_size 2
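This DPO script uses --padding_free true rather than --packing. A minimal sketch of the padding-free idea (illustrative only; in practice flash-attention's varlen kernels consume the boundary offsets): the samples in a batch are concatenated into a single row, and cumulative sequence lengths replace the attention mask, so no pad tokens are stored or computed.

import torch

def to_padding_free(samples):
    # samples: list of 1-D token tensors with different lengths
    flat = torch.cat(samples).unsqueeze(0)   # one (1, total_tokens) row, zero padding
    lens = torch.tensor([0] + [len(s) for s in samples])
    cu_seqlens = torch.cumsum(lens, dim=0, dtype=torch.int32)  # sample boundaries
    return flat, cu_seqlens

ids, cu = to_padding_free([torch.arange(3), torch.arange(5), torch.arange(2)])
print(ids.shape, cu.tolist())  # torch.Size([1, 10]) [0, 3, 8, 10]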
Lines changed: 45 additions & 0 deletions (new file)

NPROC_PER_NODE=4 \
PYTORCH_CUDA_ALLOC_CONF='' \
SEQUENCE_PARALLEL_IMPL=ring_attention \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B \
    --train_type full \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.5 \
    --vllm_max_model_len 2048 \
    --vllm_tensor_parallel_size 4 \
    --dataset AI-MO/NuminaMath-TIR#5000 \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --max_length 2048 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --eval_steps 1000 \
    --save_steps 1000 \
    --learning_rate 1e-6 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --max_completion_length 1024 \
    --reward_funcs accuracy format \
    --num_generations 4 \
    --system examples/train/grpo/prompt.txt \
    --deepspeed zero3 \
    --temperature 1.0 \
    --top_p 1.0 \
    --top_k 80 \
    --attn_impl flash_attn \
    --log_completions true \
    --async_generate false \
    --offload_optimizer true \
    --offload_model true \
    --padding_free true \
    --sequence_parallel_size 4 \
    --gc_collect_after_offload true \
    --dataloader_drop_last true \
    --sleep_level 1
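For context on the GRPO flags above (--num_generations 4, --reward_funcs accuracy format), below is a hedged sketch of the group-relative advantage that defines GRPO, in its generic textbook form rather than ms-swift's exact code: the rewards for the num_generations completions of each prompt are normalized within that group.

import torch

def grpo_advantages(rewards: torch.Tensor, num_generations: int = 4, eps: float = 1e-4):
    # rewards: flat tensor of shape (num_prompts * num_generations,)
    groups = rewards.view(-1, num_generations)        # one row per prompt
    mean = groups.mean(dim=1, keepdim=True)
    std = groups.std(dim=1, keepdim=True)
    return ((groups - mean) / (std + eps)).flatten()  # group-normalized advantages

# Two prompts, four completions each: rewards above the group average get
# positive advantages, rewards below it get negative ones.
print(grpo_advantages(torch.tensor([1.0, 0.0, 1.0, 0.0, 0.5, 0.5, 1.0, 0.0])))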
