Commit ecfb7d6

MaruyamaAya and oahzxl authored
Shard ema (#15)
* shard ema among devices and modify checkpointing
* shard ema among devices and modify checkpointing
* rewrite sharding to support padding.
* rewrite sharding to support padding.
* add test for ema sharding
* add test for ema sharding
* Delete test.sh
* update ema sharding in both scripts
* update checkpoint
* update test for ema sharding
* update scripts

---------

Co-authored-by: Xuanlei Zhao <[email protected]>
1 parent 0ee0d5b commit ecfb7d6

File tree: 9 files changed, +248 -156 lines changed

opendit/utils/ckpt_utils.py

Lines changed: 34 additions & 1 deletion
@@ -1,5 +1,7 @@
+import functools
 import json
 import logging
+import operator
 import os
 from typing import Tuple

@@ -11,6 +13,8 @@
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler

+from opendit.utils.operation import model_sharding
+

 def load_json(file_path: str):
     with open(file_path, "r") as f:
@@ -22,6 +26,29 @@ def save_json(data, file_path: str):
         json.dump(data, f, indent=4)


+def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
+    return tensor[: functools.reduce(operator.mul, original_shape)]
+
+
+def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
+    global_rank = dist.get_rank()
+    global_size = dist.get_world_size()
+    for name, param in model.named_parameters():
+        all_params = [torch.empty_like(param.data) for _ in range(global_size)]
+        dist.all_gather(all_params, param.data, group=dist.group.WORLD)
+        if int(global_rank) == 0:
+            all_params = torch.cat(all_params)
+            param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
+    dist.barrier()
+
+
+def record_model_param_shape(model: torch.nn.Module) -> dict:
+    param_shape = {}
+    for name, param in model.named_parameters():
+        param_shape[name] = param.shape
+    return param_shape
+
+
 def save(
     booster: Booster,
     model: nn.Module,
@@ -33,13 +60,19 @@ def save(
     batch_size: int,
     coordinator: DistCoordinator,
     save_dir: str,
+    shape_dict: dict,
 ):
     save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
     os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)

     booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
     # ema is not boosted, so we don't need to use booster.save_model
-    torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
+    model_gathering(ema, shape_dict)
+    global_rank = dist.get_rank()
+    if int(global_rank) == 0:
+        torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
+    model_sharding(ema)
+
     booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
     if lr_scheduler is not None:
         booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
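
For orientation: save now gathers the sharded EMA with model_gathering, writes ema.pt on rank 0 only, and then calls model_sharding to return the EMA to its sharded layout. The gather depends on remove_padding undoing the pad-and-split performed in model_sharding (see opendit/utils/operation.py below). A minimal single-process sketch of that round trip, simulating the per-rank shards with plain tensor ops instead of dist.all_gather; the world_size of 4 and the toy 2x5 parameter are arbitrary choices for illustration, not part of the commit:

# Single-process simulation of the shard/gather round trip (no torch.distributed needed).
import functools
import operator

import torch

world_size = 4  # assumed number of ranks
param = torch.arange(10, dtype=torch.float32).view(2, 5)  # 10 elements, not divisible by 4
original_shape = param.shape

# sharding side: flatten, pad to a multiple of world_size, split into per-rank shards
flat = param.view(-1)
padding_size = (world_size - flat.numel() % world_size) % world_size
padded = torch.nn.functional.pad(flat, [0, padding_size])
shards = padded.split(padded.numel() // world_size)

# gathering side: concatenate the shards, drop the padding, restore the original shape
gathered = torch.cat(shards)
numel = functools.reduce(operator.mul, original_shape)
restored = gathered[:numel].view(original_shape)

assert torch.equal(restored, param)

Because only rank 0 writes the gathered tensors back into param.data, the saved ema.pt ends up with the original, unpadded shapes recorded by record_model_param_shape.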

opendit/utils/operation.py

Lines changed: 14 additions & 0 deletions
@@ -71,6 +71,20 @@ def backward(ctx, *grad_output):
         return (return_grad, None, None, None)


+
+def model_sharding(model: torch.nn.Module):
+    global_rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    for name, param in model.named_parameters():
+        padding_size = (world_size - param.numel() % world_size) % world_size
+        if padding_size > 0:
+            padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+        else:
+            padding_param = param.data.view(-1)
+        splited_params = padding_param.split(padding_param.numel() // world_size)
+        splited_params = splited_params[global_rank]
+        param.data = splited_params
+
 def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
     return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
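
To see the resulting shapes concretely, the following hedged sketch drives model_sharding and model_gathering in a degenerate single-process world (gloo backend, world_size 1), so no GPUs and no padding are involved; the torch.nn.Linear module stands in for the EMA model and the MASTER_ADDR/MASTER_PORT values are arbitrary, none of this is part of the commit:

# Exercise the new helpers end to end with a world of size 1 on the gloo backend.
import os

import torch
import torch.distributed as dist

from opendit.utils.ckpt_utils import model_gathering, record_model_param_shape
from opendit.utils.operation import model_sharding

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

ema = torch.nn.Linear(10, 3)  # stand-in for the EMA model
shapes = record_model_param_shape(ema)

model_sharding(ema)
print(ema.weight.shape)  # torch.Size([30]): flattened rank-local slice

model_gathering(ema, shapes)
print(ema.weight.shape)  # torch.Size([3, 10]): original shape restored on rank 0

dist.destroy_process_group()

With more than one rank, each rank keeps roughly numel / world_size elements per parameter, which is the memory saving this commit targets.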

opendit/utils/train_utils.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+
+
+def get_model_numel(model: torch.nn.Module) -> int:
+    return sum(p.numel() for p in model.parameters())
+
+
+def format_numel_str(numel: int) -> str:
+    B = 1024**3
+    M = 1024**2
+    K = 1024
+    if numel >= B:
+        return f"{numel / B:.2f} B"
+    elif numel >= M:
+        return f"{numel / M:.2f} M"
+    elif numel >= K:
+        return f"{numel / K:.2f} K"
+    else:
+        return f"{numel}"
+
+
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+    tensor.div_(dist.get_world_size())
+    return tensor
+
+
+@torch.no_grad()
+def update_ema(
+    ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
+) -> None:
+    """
+    Step the EMA model towards the current model.
+    """
+    ema_params = OrderedDict(ema_model.named_parameters())
+    model_params = OrderedDict(model.named_parameters())
+
+    for name, param in model_params.items():
+        if name == "pos_embed":
+            continue
+        if not sharded:
+            param_data = param.data
+            ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
+        else:
+            if param.data.dtype != torch.float32:
+                param_id = id(param)
+                master_param = optimizer._param_store.working_to_master_param[param_id]
+                param_data = master_param.data
+            else:
+                param_data = param.data
+            ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
+
+
+def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
+    """
+    Set requires_grad flag for all parameters in a model.
+    """
+    for p in model.parameters():
+        p.requires_grad = flag
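
Note on the sharded branch: for fp16 working parameters, update_ema reads the fp32 master copy from the LowLevelZero optimizer's param store, which is a rank-local shard, and the matching EMA sharding layout is what makes the element-wise mul_/add_ line up shard against shard (tests/test_ema_sharding.py below checks this by comparing a sharded and an unsharded EMA after one update). The unsharded branch is a plain EMA blend; a small runnable sketch of it, where the Linear stand-in and decay=0.5 are arbitrary choices, not part of the commit:

# Plain (sharded=False) EMA update: ema <- decay * ema + (1 - decay) * model.
from copy import deepcopy

import torch

from opendit.utils.train_utils import update_ema

model = torch.nn.Linear(4, 4)
ema = deepcopy(model)

with torch.no_grad():
    model.weight.add_(1.0)  # pretend one optimizer step shifted the weights by 1

update_ema(ema, model, decay=0.5, sharded=False)

# with decay=0.5 the EMA lands exactly halfway between its old value and the new weights
assert torch.allclose(ema.weight, model.weight - 0.5)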

tests/test_checkpoint.py

Lines changed: 14 additions & 23 deletions
@@ -2,33 +2,23 @@
 import shutil

 import colossalai
+import pytest
 import torch
 import torch.distributed as dist
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import (
-    check_state_dict_equal,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
+from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from colossalai.zero import LowLevelZeroOptimizer

 from opendit.models.dit import DiT_S_2


-# stage 1 and 2 process the optimizer/mode the same way
-# only test 2 is fine
 @clear_cache_before_run()
-@parameterize("stage", [2])
-@parameterize("shard", [True, False])
-@parameterize("offload", [False, True])
-def _test_zero_checkpoint(stage: int, shard: bool, offload: bool):
-    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
+def run_zero_checkpoint(stage: int, shard: bool, offload: bool):
+    plugin = LowLevelZeroPlugin(precision="fp16", stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
     booster = Booster(plugin=plugin)
-    model = DiT_S_2()
+    model = DiT_S_2().half()
     criterion = lambda x: x.mean()
     optimizer = HybridAdam((model.parameters()), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
@@ -56,7 +46,7 @@ def _test_zero_checkpoint(stage: int, shard: bool, offload: bool):

     dist.barrier()

-    new_model = DiT_S_2()
+    new_model = DiT_S_2().half()
     new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
     new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)

@@ -77,24 +67,25 @@ def _test_zero_checkpoint(stage: int, shard: bool, offload: bool):

     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
-    torch.cuda.empty_cache()

     if dist.get_rank() == 0:
         shutil.rmtree(tempdir)
     dist.barrier()


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, stage: int, shard: bool, offload: bool):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
-    _test_zero_checkpoint()
+    run_zero_checkpoint(stage=stage, shard=shard, offload=offload)
     torch.cuda.empty_cache()


+@pytest.mark.parametrize("stage", [2])
+@pytest.mark.parametrize("shard", [True, False])
+@pytest.mark.parametrize("offload", [False, True])
 @rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_zero_checkpoint():
-    spawn(run_dist, 2)
+def test_zero_checkpoint(stage, shard, offload):
+    spawn(run_dist, 2, stage=stage, shard=shard, offload=offload)


 if __name__ == "__main__":
-    test_zero_checkpoint()
+    test_zero_checkpoint(2, True, False)

tests/test_ema_sharding.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+from copy import deepcopy
+
+import colossalai
+import torch
+import torch.distributed as dist
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+from opendit.models.dit import DiT_S_2
+from opendit.utils.ckpt_utils import model_gathering, record_model_param_shape
+from opendit.utils.operation import model_sharding
+from opendit.utils.train_utils import update_ema
+
+
+@clear_cache_before_run()
+def run_ema_sharding():
+    plugin = LowLevelZeroPlugin(precision="fp16", stage=2, max_norm=1.0, initial_scale=32)
+    booster = Booster(plugin=plugin)
+    model = DiT_S_2().cuda().half()
+
+    ema_sharding = deepcopy(model).eval()
+    model_param_shape = record_model_param_shape(ema_sharding)
+    model_sharding(ema_sharding)
+    ema_no_sharding = deepcopy(model).eval()
+
+    criterion = lambda x: x.mean()
+    optimizer = HybridAdam((model.parameters()), lr=0.001)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    x = torch.randn(2, 4, 32, 32).cuda().requires_grad_(True)
+    y = torch.randint(0, 10, (2,)).cuda()
+    t = torch.randint(0, 10, (2,)).cuda()
+    output = model(x, y, t)
+    loss = criterion(output)
+    booster.backward(loss, optimizer)
+    optimizer.step()
+
+    update_ema(ema_sharding, model.module, optimizer=optimizer, sharded=True, decay=0.5)
+    update_ema(ema_no_sharding, model.module, optimizer=optimizer, sharded=False, decay=0.5)
+
+    # should be equal after update
+    gather_ema_sharding = deepcopy(ema_sharding)
+    model_gathering(gather_ema_sharding, model_param_shape)
+    if dist.get_rank() == 0:
+        for (gather_ema_sharding_name, gather_ema_sharding_param), (ema_no_sharding_name, ema_no_sharding_param) in zip(
+            gather_ema_sharding.named_parameters(), ema_no_sharding.named_parameters()
+        ):
+            assert gather_ema_sharding_name == ema_no_sharding_name
+            if gather_ema_sharding_name == "pos_embed":
+                continue
+            assert torch.allclose(gather_ema_sharding_param, ema_no_sharding_param)
+    dist.barrier()
+
+    # should be same after sharding again
+    if dist.get_rank() == 0:
+        model_sharding(gather_ema_sharding)
+        for (gather_ema_sharding_name, gather_ema_sharding_param), (ema_sharding_name, ema_sharding_param) in zip(
+            gather_ema_sharding.named_parameters(), ema_sharding.named_parameters()
+        ):
+            assert gather_ema_sharding_name == ema_sharding_name
+            if gather_ema_sharding_name == "pos_embed":
+                continue
+            assert torch.allclose(gather_ema_sharding_param, ema_sharding_param)
+    dist.barrier()
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
+    run_ema_sharding()
+    torch.cuda.empty_cache()
+
+
+@rerun_if_address_is_in_use()
+def test_ema_sharding():
+    spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_ema_sharding()
