Skip to content

Commit a126464

Browse files
authored
[dist] fix: make OptimizerState EP-dim aware to fix its dcp saving (#228)
Similar to ModelState, OptimizerState also needs to be EP-dim aware so that dcp can save it properly. It now follows the same workflow as ModelState: * before saving with dcp, restore the EP dim * after loading the state dict from dcp, drop the EP dim
1 parent 74ca8d7 commit a126464

File tree

7 files changed

+322
-100
lines changed

7 files changed

+322
-100
lines changed

.github/workflows/gpu_unit_tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ jobs:
7777
- name: Run models tests
7878
run: |
7979
pytest -s -x tests/models/test_models_patch.py
80+
- name: Run e2e dcp save and load test
81+
run: |
82+
pytest -s -x tests/checkpoints/test_trainer_saveload.py
8083
8184
8285
cleanup:
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{
2+
"architectures": [
3+
"Qwen3MoeForCausalLM"
4+
],
5+
"attention_bias": false,
6+
"attention_dropout": 0.0,
7+
"bos_token_id": 151643,
8+
"decoder_sparse_step": 1,
9+
"eos_token_id": 151645,
10+
"head_dim": 128,
11+
"hidden_act": "silu",
12+
"hidden_size": 2048,
13+
"initializer_range": 0.02,
14+
"intermediate_size": 6144,
15+
"max_position_embeddings": 262144,
16+
"max_window_layers": 48,
17+
"mlp_only_layers": [],
18+
"model_type": "qwen3_moe",
19+
"moe_intermediate_size": 768,
20+
"norm_topk_prob": true,
21+
"num_attention_heads": 32,
22+
"num_experts": 128,
23+
"num_experts_per_tok": 8,
24+
"num_hidden_layers": 4,
25+
"num_key_value_heads": 4,
26+
"output_router_logits": false,
27+
"rms_norm_eps": 1e-06,
28+
"rope_scaling": null,
29+
"rope_theta": 10000000,
30+
"router_aux_loss_coef": 0.001,
31+
"sliding_window": null,
32+
"tie_word_embeddings": false,
33+
"torch_dtype": "bfloat16",
34+
"transformers_version": "4.51.3",
35+
"use_cache": true,
36+
"use_sliding_window": false,
37+
"vocab_size": 151936
38+
}

tests/checkpoints/ep4.yaml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
model:
2+
# model_path: ./qwen3moe_4layers_merged
3+
config_path: configs/model_configs/qwen/qwen3_moe_30a3b_4_layers.json
4+
tokenizer_path: Qwen/Qwen3-30B-A3B
5+
weight_path: None
6+
moe_implementation: fused
7+
attn_implementation: flash_attention_2
8+
9+
data:
10+
train_path: dummy
11+
max_seq_len: 128
12+
13+
train:
14+
output_dir: ./test_trainer_saveload_ep4
15+
data_parallel_mode: fsdp2
16+
expert_parallel_size: 4
17+
enable_full_shard: true
18+
init_device: meta
19+
global_batch_size: 8
20+
micro_batch_size: 1
21+
rmpad: false
22+
rmpad_with_pos_ids: true
23+
dyn_bsz_margin: 0
24+
lr: 3.0e-4
25+
lr_warmup_ratio: 0.007
26+
lr_decay_style: constant
27+
lr_decay_ratio: 1.0
28+
weight_decay: 0.01
29+
max_grad_norm: 1.0
30+
use_wandb: false
31+
enable_profiling: false
32+
max_steps: 5
33+
ckpt_manager: dcp
34+
save_async: true

tests/checkpoints/ep8.yaml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
model:
2+
# model_path: ./qwen3moe_4layers_merged
3+
config_path: configs/model_configs/qwen/qwen3_moe_30a3b_4_layers.json
4+
tokenizer_path: Qwen/Qwen3-30B-A3B
5+
weight_path: None
6+
moe_implementation: fused
7+
attn_implementation: flash_attention_2
8+
9+
data:
10+
train_path: dummy
11+
max_seq_len: 128
12+
13+
train:
14+
output_dir: ./test_trainer_saveload_ep8
15+
data_parallel_mode: fsdp2
16+
expert_parallel_size: 8
17+
enable_full_shard: true
18+
init_device: meta
19+
global_batch_size: 8
20+
micro_batch_size: 1
21+
rmpad: false
22+
rmpad_with_pos_ids: true
23+
dyn_bsz_margin: 0
24+
lr: 3.0e-4
25+
lr_warmup_ratio: 0.007
26+
lr_decay_style: constant
27+
lr_decay_ratio: 1.0
28+
weight_decay: 0.01
29+
max_grad_norm: 1.0
30+
use_wandb: false
31+
enable_profiling: false
32+
max_steps: 5
33+
ckpt_manager: dcp
34+
save_async: true

tests/checkpoints/no_ep.yaml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
model:
2+
# model_path: ./qwen3moe_4layers_merged
3+
config_path: configs/model_configs/qwen/qwen3_moe_30a3b_4_layers.json
4+
tokenizer_path: Qwen/Qwen3-30B-A3B
5+
weight_path: None
6+
moe_implementation: fused
7+
attn_implementation: flash_attention_2
8+
9+
data:
10+
train_path: dummy
11+
max_seq_len: 128
12+
13+
train:
14+
output_dir: ./test_trainer_saveload_no_ep
15+
data_parallel_mode: fsdp2
16+
expert_parallel_size: 1
17+
enable_full_shard: true
18+
init_device: meta
19+
global_batch_size: 8
20+
micro_batch_size: 1
21+
rmpad: false
22+
rmpad_with_pos_ids: true
23+
dyn_bsz_margin: 0
24+
lr: 3.0e-4
25+
lr_warmup_ratio: 0.007
26+
lr_decay_style: constant
27+
lr_decay_ratio: 1.0
28+
weight_decay: 0.01
29+
max_grad_norm: 1.0
30+
use_wandb: false
31+
enable_profiling: false
32+
max_steps: 5
33+
ckpt_manager: dcp
34+
save_async: true

tests/utils/test_trainer_saveload.py renamed to tests/checkpoints/test_trainer_saveload.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
import subprocess
34
from dataclasses import asdict, dataclass, field
45
from typing import Any, Dict, List, Optional
56

@@ -8,7 +9,7 @@
89
from tqdm import trange
910

1011
from veomni.checkpoint import build_checkpointer
11-
from veomni.data import build_dummy_dataset, build_streaming_dataloader
12+
from veomni.data import build_dataloader, build_dummy_dataset
1213
from veomni.distributed.offloading import build_activation_offloading_context
1314
from veomni.distributed.parallel_state import get_parallel_state, init_parallel_state
1415
from veomni.distributed.torch_parallelize import build_parallelize_model
@@ -21,37 +22,24 @@
2122

2223

2324
"""
24-
torchrun --nnodes=1 --nproc-per-node=8 --master-port=4321 tests/utils/test_trainer_saveload.py \
25-
--model.model_path Qwen/Qwen3-4B \
26-
--train.expert_parallel_size 1 \
27-
--train.global_batch_size 8 \
28-
--train.micro_batch_size 1 \
29-
--data.max_seq_len 128 \
30-
--data.train_path "dummy" \
31-
--train.output_dir ./test_trainer_saveload \
32-
--train.max_steps 5 \
33-
--train.rmpad false \
34-
--train.rmpad_with_pos_ids true \
35-
--train.data_parallel_mode "fsdp2" \
36-
--train.init_device "meta" \
37-
--train.ckpt_manager "dcp"
38-
39-
torchrun --nnodes=1 --nproc-per-node=8 --master-port=4321 tests/utils/test_trainer_saveload.py \
40-
--model.model_path /path/to/Qwen3-30B-A3B-Instruct-2507-merge \
25+
torchrun --nnodes=1 --nproc-per-node=8 --master-port=4321 tests/checkpoints/test_trainer_saveload.py \
26+
--model.config_path configs/model_configs/qwen/qwen3_moe_30a3b_4_layers.json \
27+
--model.weight_path None \
28+
--model.tokenizer_path /mnt/hdfs/models/Qwen3-30B-A3B \
4129
--model.moe_implementation fused \
4230
--model.attn_implementation flash_attention_2 \
43-
--train.expert_parallel_size 4 \
31+
--train.expert_parallel_size 8 \
4432
--train.global_batch_size 8 \
4533
--train.micro_batch_size 1 \
4634
--data.max_seq_len 128 \
4735
--data.train_path "dummy" \
48-
--train.output_dir ./test_trainer_saveload \
36+
--train.output_dir ./test_trainer_saveload_ep8 \
4937
--train.max_steps 5 \
5038
--train.rmpad false \
5139
--train.rmpad_with_pos_ids true \
5240
--train.data_parallel_mode "fsdp2" \
5341
--train.init_device "meta" \
54-
--train.ckpt_manager "dcp"
42+
--train.ckpt_manager "dcp" $@ 2>&1 | tee test_saveload_ep8.log
5543
"""
5644

5745
# To prevent DCP from complaining "too many open files"
@@ -143,8 +131,9 @@ def main():
143131
train_dataset = build_dummy_dataset(task_type="text", size=train_data_size, max_seq_len=args.data.max_seq_len)
144132

145133
args.train.compute_train_steps(args.data.max_seq_len, args.data.train_size)
146-
train_dataloader = build_streaming_dataloader(
134+
train_dataloader = build_dataloader(
147135
dataset=train_dataset,
136+
dataloader_type="streaming",
148137
micro_batch_size=args.train.micro_batch_size,
149138
global_batch_size=args.train.global_batch_size,
150139
dataloader_batch_size=args.train.dataloader_batch_size,
@@ -356,5 +345,44 @@ def step_id(s):
356345
dist.destroy_process_group()
357346

358347

348+
def test_trainer_saveload_ep8():
    """Run the EP=8 trainer save/load test end-to-end as an 8-rank torchrun job.

    Launches the checkpoint test script with the ep8 YAML config in a
    subprocess; ``check=True`` raises ``CalledProcessError`` on a non-zero
    exit, so a hang-free, zero-exit run is the pass condition.
    """
    ep8_command = [
        "torchrun",
        "--nnodes=1",
        "--nproc_per_node=8",
        "--master_port=4321",
        # Fixed: the script was renamed from tests/utils/ to tests/checkpoints/
        # in this commit; the old path no longer exists.
        "tests/checkpoints/test_trainer_saveload.py",
        "tests/checkpoints/ep8.yaml",
    ]
    ep8_result = subprocess.run(ep8_command, check=True)
    # check=True already raises on failure; this assert is a readable
    # belt-and-braces confirmation for the pytest report.
    assert ep8_result.returncode == 0
359+
360+
361+
def test_trainer_saveload_ep4():
    """Run the EP=4 trainer save/load test as an 8-rank torchrun job.

    Spawns the checkpoint test script with the ep4 YAML config; a non-zero
    exit makes ``subprocess.run(check=True)`` raise, failing the test.
    """
    launcher = [
        "torchrun",
        "--nnodes=1",
        "--nproc_per_node=8",
        "--master_port=4321",
    ]
    target = [
        "tests/checkpoints/test_trainer_saveload.py",
        "tests/checkpoints/ep4.yaml",
    ]
    completed = subprocess.run(launcher + target, check=True)
    # Redundant with check=True, but documents the expected outcome.
    assert completed.returncode == 0
372+
373+
374+
def test_trainer_saveload_no_ep():
    """Run the trainer save/load test without expert parallelism (EP=1).

    Uses the no_ep YAML config under an 8-process torchrun launch; the test
    passes iff the subprocess exits with status 0 (``check=True`` raises
    otherwise).
    """
    argv = []
    argv.append("torchrun")
    argv.append("--nnodes=1")
    argv.append("--nproc_per_node=8")
    argv.append("--master_port=4321")
    argv.append("tests/checkpoints/test_trainer_saveload.py")
    argv.append("tests/checkpoints/no_ep.yaml")
    proc = subprocess.run(argv, check=True)
    # Explicit success check mirroring the sibling EP tests.
    assert proc.returncode == 0
385+
386+
359387
if __name__ == "__main__":
360388
main()

0 commit comments

Comments
 (0)