
Commit f1ce7e2

oahzxl and MaruyamaAya authored
update test (#23)
* shard ema among devices and modify checkpointing
* rewrite sharding to support padding
* add test for ema sharding
* Delete test.sh
* update ema sharding in both scripts
* update checkpoint
* update test for ema sharding
* update scripts
* fix tests
* update test

Co-authored-by: Maruyama_Aya <[email protected]>
Co-authored-by: Ziming Liu <[email protected]>
1 parent ecfb7d6 commit f1ce7e2
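The sharding and gathering helpers exercised below (model_sharding, model_gathering, and record_model_param_shape from opendit.utils) are not part of this diff. As a rough, hypothetical sketch of the approach described in the commit message (flatten each EMA parameter, pad it so it splits evenly across ranks, keep one chunk per device, and all-gather the chunks back into the recorded shapes), assuming record_model_param_shape returns a {name: shape} mapping:

```python
import torch
import torch.distributed as dist


def shard_ema_params(ema_model):
    """Simplified stand-in for model_sharding (not the real implementation):
    keep only this rank's slice of every EMA parameter, padding with zeros
    so each parameter splits evenly across the world size."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    for _, param in ema_model.named_parameters():
        flat = param.data.flatten()
        pad = (world_size - flat.numel() % world_size) % world_size
        if pad:
            flat = torch.cat([flat, flat.new_zeros(pad)])
        # each rank keeps one contiguous 1/world_size chunk
        param.data = flat.chunk(world_size)[rank].clone()


def gather_ema_params(ema_model, param_shapes):
    """Simplified stand-in for model_gathering: all-gather the local chunks,
    drop the padding, and restore the shapes recorded before sharding."""
    world_size = dist.get_world_size()
    for name, param in ema_model.named_parameters():
        chunks = [torch.empty_like(param.data) for _ in range(world_size)]
        dist.all_gather(chunks, param.data)
        numel = int(torch.Size(param_shapes[name]).numel())
        full = torch.cat(chunks)[:numel]
        param.data = full.reshape(param_shapes[name])
```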

File tree

4 files changed: +33 -26 lines

  tests/test_checkpoint.py
  tests/test_ema_sharding.py
  tests/test_flash_attention.py
  tests/test_sequence_parallel.py


tests/test_checkpoint.py

Lines changed: 5 additions & 7 deletions

@@ -8,17 +8,16 @@
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.testing import check_state_dict_equal, rerun_if_address_is_in_use, spawn
 from colossalai.zero import LowLevelZeroOptimizer
 
-from opendit.models.dit import DiT_S_2
+from opendit.models.dit import DiT
 
 
-@clear_cache_before_run()
 def run_zero_checkpoint(stage: int, shard: bool, offload: bool):
     plugin = LowLevelZeroPlugin(precision="fp16", stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
     booster = Booster(plugin=plugin)
-    model = DiT_S_2().half()
+    model = DiT(depth=2, hidden_size=64, patch_size=2, num_heads=4).half()
     criterion = lambda x: x.mean()
     optimizer = HybridAdam((model.parameters()), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
@@ -40,13 +39,12 @@ def run_zero_checkpoint(stage: int, shard: bool, offload: bool):
 
     model_ckpt_path = f"{tempdir}/model"
     optimizer_ckpt_path = f"{tempdir}/optimizer"
-    # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
     booster.save_model(model, model_ckpt_path, shard=shard)
     booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
 
     dist.barrier()
 
-    new_model = DiT_S_2().half()
+    new_model = DiT(depth=2, hidden_size=64, patch_size=2, num_heads=4).half()
     new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
     new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
 
@@ -67,6 +65,7 @@ def run_zero_checkpoint(stage: int, shard: bool, offload: bool):
 
     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
+    dist.barrier()
 
     if dist.get_rank() == 0:
         shutil.rmtree(tempdir)
@@ -76,7 +75,6 @@ def run_zero_checkpoint(stage: int, shard: bool, offload: bool):
 def run_dist(rank, world_size, port, stage: int, shard: bool, offload: bool):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
     run_zero_checkpoint(stage=stage, shard=shard, offload=offload)
-    torch.cuda.empty_cache()
 
 
 @pytest.mark.parametrize("stage", [2])
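For context on the model swap in these tests: DiT_S_2 is presumably a preset matching the DiT-S/2 configuration from the DiT paper, while the updated tests construct a much smaller DiT inline so the checkpoint round-trip stays fast. A hypothetical comparison (the actual DiT_S_2 definition is not shown in this diff):

```python
from opendit.models.dit import DiT

# Presumed old preset, roughly the DiT-S/2 configuration (hypothetical):
# def DiT_S_2(**kwargs):
#     return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

# Test-sized model now built directly, as in the updated tests:
model = DiT(depth=2, hidden_size=64, patch_size=2, num_heads=4).half()
```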

tests/test_ema_sharding.py

Lines changed: 22 additions & 16 deletions

@@ -1,3 +1,4 @@
+import os
 from copy import deepcopy
 
 import colossalai
@@ -8,22 +9,31 @@
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
 
-from opendit.models.dit import DiT_S_2
+from opendit.models.dit import DiT
 from opendit.utils.ckpt_utils import model_gathering, record_model_param_shape
 from opendit.utils.operation import model_sharding
 from opendit.utils.train_utils import update_ema
 
 
+def assert_params_equal(model1, model2):
+    for (name1, param1), (name2, param2) in zip(model1.named_parameters(), model2.named_parameters()):
+        assert name1 == name2
+        if name1 == "pos_embed":
+            continue
+        assert torch.allclose(param1, param2)
+
+
 @clear_cache_before_run()
 def run_ema_sharding():
     plugin = LowLevelZeroPlugin(precision="fp16", stage=2, max_norm=1.0, initial_scale=32)
     booster = Booster(plugin=plugin)
-    model = DiT_S_2().cuda().half()
+    model = DiT(depth=2, hidden_size=64, patch_size=2, num_heads=4).cuda().half()
 
     ema_sharding = deepcopy(model).eval()
     model_param_shape = record_model_param_shape(ema_sharding)
     model_sharding(ema_sharding)
     ema_no_sharding = deepcopy(model).eval()
+    ema_to_read = deepcopy(model).eval()
 
     criterion = lambda x: x.mean()
     optimizer = HybridAdam((model.parameters()), lr=0.001)
@@ -44,25 +54,21 @@ def run_ema_sharding():
     gather_ema_sharding = deepcopy(ema_sharding)
     model_gathering(gather_ema_sharding, model_param_shape)
     if dist.get_rank() == 0:
-        for (gather_ema_sharding_name, gather_ema_sharding_param), (ema_no_sharding_name, ema_no_sharding_param) in zip(
-            gather_ema_sharding.named_parameters(), ema_no_sharding.named_parameters()
-        ):
-            assert gather_ema_sharding_name == ema_no_sharding_name
-            if gather_ema_sharding_name == "pos_embed":
-                continue
-            assert torch.allclose(gather_ema_sharding_param, ema_no_sharding_param)
+        assert_params_equal(gather_ema_sharding, ema_no_sharding)
+    dist.barrier()
+
+    # should be same after read again
+    if dist.get_rank() == 0:
+        torch.save(gather_ema_sharding.state_dict(), "tmp.pth")
+        ema_to_read.load_state_dict(torch.load("tmp.pth"))
+        assert_params_equal(gather_ema_sharding, ema_to_read)
+        os.remove("tmp.pth")
     dist.barrier()
 
     # should be same after sharding again
     if dist.get_rank() == 0:
         model_sharding(gather_ema_sharding)
-        for (gather_ema_sharding_name, gather_ema_sharding_param), (ema_sharding_name, ema_sharding_param) in zip(
-            gather_ema_sharding.named_parameters(), ema_sharding.named_parameters()
-        ):
-            assert gather_ema_sharding_name == ema_sharding_name
-            if gather_ema_sharding_name == "pos_embed":
-                continue
-            assert torch.allclose(gather_ema_sharding_param, ema_sharding_param)
+        assert_params_equal(gather_ema_sharding, ema_sharding)
     dist.barrier()
 
 
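The test above drives update_ema from opendit.utils.train_utils, whose implementation is not shown in this diff. For reference, a minimal sketch of the standard exponential-moving-average step it is assumed to perform; with a sharded EMA, the source parameter would first be flattened, padded, and sliced to the same local chunk as the EMA shard before this update:

```python
import torch


@torch.no_grad()
def ema_step(ema_param: torch.Tensor, param: torch.Tensor, decay: float = 0.9999) -> None:
    # Standard EMA rule: ema <- decay * ema + (1 - decay) * param.
    # Sketch only: the real update_ema iterates over all model parameters and,
    # for a sharded EMA, matches each local shard to its slice of the full
    # parameter before applying this rule.
    ema_param.mul_(decay).add_(param.detach(), alpha=1 - decay)
```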
tests/test_flash_attention.py

Lines changed: 5 additions & 2 deletions

@@ -1,6 +1,8 @@
 import copy
 
 import colossalai
+import flash_attn
+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -120,7 +122,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-def flash_attn(seq_len, hidden_dim, head_num, batch_size):
+def _run_flash_attn(seq_len, hidden_dim, head_num, batch_size):
     seq_len = seq_len
     hidden_dim = hidden_dim
     head_num = head_num
@@ -231,14 +233,15 @@ def flash_attn(seq_len, hidden_dim, head_num, batch_size):
 @parameterize("head_num", [16])
 @parameterize("batch_size", [2])
 def run_flash_attn(seq_len, hidden_dim, head_num, batch_size):
-    flash_attn(seq_len, hidden_dim, head_num, batch_size)
+    _run_flash_attn(seq_len, hidden_dim, head_num, batch_size)
 
 
 def check_all2all_attn(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_flash_attn()
 
 
+@pytest.mark.skipif(flash_attn.__version__ < "2.4.1", reason="requires flashattn 2.4.1 or higher")
 @rerun_if_address_is_in_use()
 def test_flash_attn():
     spawn(check_all2all_attn, nprocs=WORKERS)
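One caveat with the new guard above: flash_attn.__version__ < "2.4.1" is a lexicographic string comparison, so a hypothetical flash-attn 2.10.x would sort before 2.4.1 and be skipped by mistake. A sketch of a parsed-version check, assuming the packaging library is available (this is not part of the commit):

```python
import flash_attn
import pytest
from packaging import version

# Skip unless flash-attn is at least 2.4.1, comparing parsed versions so that
# e.g. 2.10.0 is correctly treated as newer than 2.4.1.
requires_flash_attn = pytest.mark.skipif(
    version.parse(flash_attn.__version__) < version.parse("2.4.1"),
    reason="requires flash-attn 2.4.1 or higher",
)
```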

tests/test_sequence_parallel.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 
 from opendit.utils.operation import all_to_all_comm
 
-WORKERS = 4
+WORKERS = 2
 
 
 class DistAttention(nn.Module):
