Commit 1dda899

[megatron] update benchmark docs (#4991)

1 parent: eb7e7d8

File tree

3 files changed: +20 -12 lines changed


docs/source/Instruction/Megatron-SWIFT训练.md

Lines changed: 7 additions & 4 deletions
````diff
@@ -9,9 +9,12 @@ SWIFT introduces Megatron's parallelism techniques to accelerate the training of large models, including data
 ```shell
 # Recommended torch version: 2.5 / 2.6
 pip install pybind11
+
 # transformer_engine
 # If an installation error occurs, refer to this issue for a solution: https://github.com/modelscope/ms-swift/issues/3793
 pip install git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.3
+# If the above command fails, you can also install it as follows:
+# pip install transformer_engine[pytorch]
 
 # apex
 git clone https://github.com/NVIDIA/apex
````
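
One practical note on the fallback added above: in shells that expand square brackets as glob patterns (zsh, for example), the extra should be quoted. A minimal sketch, not part of the commit:

```shell
# Quoting prevents zsh from treating [pytorch] as a glob pattern
pip install "transformer_engine[pytorch]"
```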
```diff
@@ -134,8 +137,8 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 
 | | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 |
 | -------- | ----------- | ---------- | ---------- |
-| Training speed | 2.93s/it | 6.02s/it | 24.30s/it |
-| GPU memory usage | 8\*66GB | 8\*72GB | 8\*50GB |
+| Training speed | 2.95s/it | 6.02s/it | 24.30s/it |
+| GPU memory usage | 8\*57GB | 8\*72GB | 8\*50GB |
 
 
 ## Command Line Arguments
```
```diff
@@ -232,8 +235,8 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 - 🔥sequence_parallel: Enable the sequence-parallel optimization. Default is False.
 - 🔥context_parallel_size: CP size, default is 1.
 - tp_comm_overlap: Overlap tensor-parallel communication with GEMM (general matrix multiplication) kernels (to reduce communication time). Default is False.
-- overlap_grad_reduce: Overlap grad reduce operations in DDP (to reduce DP communication time). Default is False.
-- overlap_param_gather: Overlap the parameter all-gather in the distributed optimizer (to reduce DP communication time). Default is False.
+- 🔥overlap_grad_reduce: Overlap grad reduce operations in DDP (to reduce DP communication time). Default is False.
+- 🔥overlap_param_gather: Overlap the parameter all-gather in the distributed optimizer (to reduce DP communication time). Default is False.
 - distributed_timeout_minutes: Timeout for torch.distributed (in minutes). This parameter is ineffective; it is controlled by ddp_timeout in the [Base Arguments](./命令行参数.md#基本参数), with a default of 300000 minutes.
 
 **Logging Parameters**:
```

docs/source_en/Instruction/Megatron-SWIFT-Training.md

Lines changed: 7 additions & 4 deletions
````diff
@@ -10,9 +10,12 @@ To use Megatron-SWIFT, in addition to installing the `swift` dependencies, you a
 ```shell
 # Recommended PyTorch version: 2.5 / 2.6
 pip install pybind11
+
 # transformer_engine
 # If an installation error occurs, you can refer to this issue for resolution: https://github.com/modelscope/ms-swift/issues/3793
 pip install git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.3
+# If the above command fails, you can also install it using the following command:
+# pip install transformer_engine[pytorch]
 
 # apex
 git clone https://github.com/NVIDIA/apex
````
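
Whichever install path succeeds, a quick import check confirms that transformer_engine and apex were built against the current PyTorch before launching training. An optional sketch, not taken from the commit:

```shell
# Both extensions must import cleanly for Megatron-SWIFT to use them
python -c "import transformer_engine.pytorch; print('transformer_engine OK')"
python -c "import apex; print('apex OK')"
```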
```diff
@@ -138,8 +141,8 @@ The speed comparison of full-parameter training for Dense/MoE models using `mega
 
 | | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 |
 | ---------------- | ----------- | --------------- | --------------- |
-| Training Speed | 2.93s/it | 6.02s/it | 24.30s/it |
-| GPU Memory Usage | 8\*66GB | 8\*72GB | 8\*50GB |
+| Training Speed | 2.95s/it | 6.02s/it | 24.30s/it |
+| GPU Memory Usage | 8\*57GB | 8\*72GB | 8\*50GB |
 
 ## Command Line Arguments
 
```
```diff
@@ -239,8 +242,8 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the
 - 🔥sequence_parallel: Enable sequence parallel optimization. Default is False.
 - 🔥context_parallel_size: CP (Context Parallelism) size, default is 1.
 - tp_comm_overlap: Overlap tensor parallel communication with GEMM (General Matrix Multiplication) kernels (to reduce communication time). Default is False.
-- overlap_grad_reduce: Overlap grad reduction operations in DDP (to reduce DP communication time). Default is False.
-- overlap_param_gather: Overlap all-gather of parameters in the distributed optimizer (to reduce DP communication time). Default is False.
+- 🔥overlap_grad_reduce: Overlap grad reduction operations in DDP (to reduce DP communication time). Default is False.
+- 🔥overlap_param_gather: Overlap all-gather of parameters in the distributed optimizer (to reduce DP communication time). Default is False.
 - distributed_timeout_minutes: The timeout duration for torch.distributed (in minutes). This parameter is deprecated and is now controlled by the `ddp_timeout` in the [Base Arguments](./Command-line-parameters.md#base-arguments), with a default value of 300000 minutes.
 
 **Logging Parameters**:
```
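
For context, the two newly highlighted overlap flags are passed like any other `megatron sft` argument. A minimal sketch with placeholder checkpoint and dataset values (assumptions, not a command from this commit):

```shell
# Sketch only: enable DP-communication overlap on top of a basic fine-tuning run
NPROC_PER_NODE=8 \
megatron sft \
    --load <megatron_checkpoint_dir> \
    --dataset <dataset_id_or_path> \
    --sequence_parallel true \
    --overlap_grad_reduce true \
    --overlap_param_gather true
```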

examples/train/megatron/moe/moe.sh

Lines changed: 6 additions & 4 deletions
```diff
@@ -1,5 +1,4 @@
-# pp2ep4: 7 * 73GiB, 2.5s/it
-# tp2ep4: 8 * 65GiB, 3s/it
+# 8 * 57GiB, 2.95s/it
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
 NPROC_PER_NODE=8 \
 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
@@ -8,6 +7,7 @@ megatron sft \
     --dataset 'liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT' \
     --split_dataset_ratio 0.01 \
     --pipeline_model_parallel_size 2 \
+    --decoder_last_pipeline_num_layers 11 \
     --expert_model_parallel_size 4 \
     --moe_grouped_gemm true \
     --moe_shared_expert_overlap true \
@@ -17,7 +17,9 @@ megatron sft \
     --packing true \
     --moe_permute_fusion true \
     --moe_router_dtype fp32 \
-    --recompute_granularity selective \
+    --recompute_granularity full \
+    --recompute_method uniform \
+    --recompute_num_layers 1 \
     --max_epochs 1 \
     --finetune true \
     --cross_entropy_loss_fusion true \
@@ -33,4 +35,4 @@ megatron sft \
     --no_save_optim true \
     --no_save_rng true \
     --sequence_parallel true \
-    --use_flash_attn true
+    --attention_backend flash
```
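
The script now trades compute for memory: full, uniform activation recomputation with one layer per checkpoint group, and `--attention_backend flash` replaces the removed `--use_flash_attn` flag. A minimal sketch isolating just these changed knobs, with placeholder checkpoint and dataset values (assumptions, not part of the commit):

```shell
# Sketch only: the recompute and attention flags changed by this commit
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
megatron sft \
    --load <megatron_checkpoint_dir> \
    --dataset <dataset_id_or_path> \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --attention_backend flash
```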
