
Commit 0ee0d5b

resume img pipeline and add tests (#22)
* init
* Update dependencies and installation instructions
* rename video train
* update script
* Update train.sh to train_video.sh and add train_img.sh
* Added matplotlib and flash_attn to requirements.txt
* Remove data-path argument from train_img.sh
* update args
* update test
* Add test for new model implementation
* Add weight initialization function and update model comparison in test_model()
* add ckpt
* update clip test
* Commented out unnecessary flags in train_img.sh script
* fix bugs
* Add transformers library to requirements.txt
* update
1 parent 268978a commit 0ee0d5b

File tree

12 files changed: +1014 −62 lines

README.md

Lines changed: 34 additions & 13 deletions
@@ -1,25 +1,46 @@
-## Usage
+# OpenDiT
+### Install ColossalAI
+```
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI
+git checkout adae123df3badfb15d044bd416f0cf29f250bc86
+pip install -e .
+```
 
+### Install OpenDiT
 ```
 # Prerequisite
 cd OpenDiT
 pip install -e .
-
-# Train
-bash preprocess.sh
-bash train.sh
-
-# Infer
+```
+## Image Pipeline
+```
+# train
+bash train_img.sh
+# inference
 bash sample.sh
 ```
-
-## Install ColossalAI
+## Video Pipeline
 ```
-git clone https://github.com/hpcaitech/ColossalAI.git
-cd ColossalAI
-git checkout adae123df3badfb15d044bd416f0cf29f250bc86
-pip install -e .
+# train
+bash preprocess.sh
+bash train_video.sh
 ```
+## Install kernels to speed up
+```
+# triton for modulate kernel
+pip install triton
+
+# flash attention
+pip install flash-attn
+
+# apex layernorm
+git clone https://github.com/NVIDIA/apex.git
+cd apex
+git checkout 741bdf50825a97664db08574981962d66436d16a
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext" --global-option="--cpp_ext"
+```
+
 
 ## Scalable Diffusion Models with Transformers (DiT)<br><sub>Official PyTorch Implementation</sub>
 
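As an aside (not part of the commit): the optional speed-up kernels above are plain pip/Git installs, so a quick way to see which ones an environment actually has before flipping the corresponding model flags is to probe the package names the README mentions. A minimal sketch; only the package names `triton`, `flash_attn`, and `apex` are taken from the README, everything else is illustrative:

```
# Illustrative sanity check: report which optional kernels are importable.
import importlib.util

for pkg in ("triton", "flash_attn", "apex"):
    status = "available" if importlib.util.find_spec(pkg) is not None else "not installed"
    print(f"{pkg}: {status}")
```
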
opendit/models/dit.py

Lines changed: 50 additions & 44 deletions
@@ -23,12 +23,6 @@
 from opendit.models.clip import TextEmbedder
 from opendit.utils.operation import all_to_all_comm, gather_forward_split_backward
 
-ULYSSES = False
-FLASH_ATTN = False
-SP_SIZE = 2
-LAYERNORM_KERNEL = False
-MODULATE_KERNEL = False
-
 
 def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool):
     if use_kernel:

@@ -45,16 +39,17 @@ def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool):
 def modulate(norm_func, x, shift, scale, use_kernel=False):
     # Suppose x is (N, T, D), shift is (N, D), scale is (N, D)
     dtype = x.dtype
-    x = norm_func(x.to(torch.float32))
+    x, shift, scale = x.to(torch.float32), shift.to(torch.float32), scale.to(torch.float32)
+    x = norm_func(x)
     if use_kernel:
         try:
             from opendit.kernels.fused_modulate import fused_modulate
 
-            x = fused_modulate(x, scale.to(torch.float32), shift.to(torch.float32))
+            x = fused_modulate(x, scale, shift)
         except ImportError:
             raise RuntimeError("FusedModulate kernel not available. Please install triton.")
     else:
-        x = x * (scale.to(torch.float32).unsqueeze(1) + 1) + shift.to(torch.float32).unsqueeze(1)
+        x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
     x = x.to(dtype)
 
     return x

@@ -156,8 +151,8 @@ def __init__(
         attn_drop: float = 0.0,
         proj_drop: float = 0.0,
         norm_layer: nn.Module = nn.LayerNorm,
-        use_flash_attn: bool = False,
-        enable_sequence_parallelism: bool = False,
+        enable_flashattn: bool = False,
+        sequence_parallel_size: int = 1,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, "dim should be divisible by num_heads"

@@ -172,16 +167,20 @@ def __init__(
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
-        self.use_flash_attn = use_flash_attn
-        self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.enable_flashattn = enable_flashattn
+        # TODO: support sequence_parallel_size > 2
+        assert sequence_parallel_size in [1, 2], "sequence_parallel_size is only supported for 1 or 2"
+        self.sequence_parallel_size = sequence_parallel_size
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x)  # (B, N, C), N here is N_total // SP_SIZE
         # Todo: Change num_heads in somewhere else for a better code style
-        num_heads = self.num_heads if not self.enable_sequence_parallelism else self.num_heads // SP_SIZE
+        num_heads = (
+            self.num_heads if self.sequence_parallel_size == 1 else self.num_heads // self.sequence_parallel_size
+        )
 
-        if self.enable_sequence_parallelism:
+        if self.sequence_parallel_size > 1:
             q, k, v = qkv.split(self.head_dim * self.num_heads, dim=-1)
             # q = q.reshape(1, -1, self.head_dim * self.num_heads)
             # k = k.reshape(1, -1, self.head_dim * self.num_heads)

@@ -191,9 +190,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             k = all_to_all_comm(k, None)
             v = all_to_all_comm(v, None)
 
-            q = q.reshape(B, N * SP_SIZE, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
-            k = k.reshape(B, N * SP_SIZE, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
-            v = v.reshape(B, N * SP_SIZE, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+            q = q.reshape(B, N * self.sequence_parallel_size, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+            k = k.reshape(B, N * self.sequence_parallel_size, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+            v = v.reshape(B, N * self.sequence_parallel_size, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
 
         else:
             # Todo: chunked flash attention

@@ -204,7 +203,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             # .permute(2, 3, 0, 1, 4)
             # .reshape(3, B * num_heads, 1, N, self.head_dim)
             # )
-            if self.use_flash_attn:
+            if self.enable_flashattn:
                 # [3, B, num_heads, N, head_dim] => [B, N, num_heads, head_dim] * 3
                 qkv = qkv.reshape(B, N, 3, num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
             else:

@@ -213,7 +212,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         q, k, v = qkv.unbind(0)
         q, k = self.q_norm(q), self.k_norm(k)
 
-        if self.use_flash_attn:
+        if self.enable_flashattn:
             from flash_attn import flash_attn_func
 
             # Todo: chunked flash attention

@@ -258,10 +257,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = attn @ v
 
         x_output_shape = (
-            (B, N, C) if not self.enable_sequence_parallelism else (B, N * SP_SIZE, num_heads * self.head_dim)
+            (B, N, C)
+            if self.sequence_parallel_size == 1
+            else (B, N * self.sequence_parallel_size, num_heads * self.head_dim)
         )
         x = x.transpose(1, 2).reshape(x_output_shape)
-        if self.enable_sequence_parallelism:
+        if self.sequence_parallel_size > 1:
             # Todo: Use all_to_all_single for x
             # x = x.reshape(1, -1, num_heads * self.head_dim)
             x = all_to_all_comm(x, None, scatter_dim=1, gather_dim=2)

@@ -281,33 +282,37 @@ def __init__(
         hidden_size,
         num_heads,
         mlp_ratio=4.0,
-        flash_attn=False,
-        sequence_parallel=False,
-        layernorm_kernel=False,
-        modulate_kernel=False,
+        enable_flashattn=False,
+        sequence_parallel_size=False,
+        enable_layernorm_kernel=False,
+        enable_modulate_kernel=False,
         **block_kwargs,
     ):
         super().__init__()
-        self.modulate_kernel = modulate_kernel
-        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=layernorm_kernel)
+        self.enable_modulate_kernel = enable_modulate_kernel
+        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
         self.attn = DistAttention(
             hidden_size,
             num_heads=num_heads,
             qkv_bias=True,
-            use_flash_attn=flash_attn,
-            enable_sequence_parallelism=sequence_parallel,
+            enable_flashattn=enable_flashattn,
+            sequence_parallel_size=sequence_parallel_size,
             **block_kwargs,
         )
-        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=layernorm_kernel)
+        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         approx_gelu = lambda: nn.GELU(approximate="tanh")
         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
 
     def forward(self, x, c):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
-        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa, self.modulate_kernel))
-        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp, self.modulate_kernel))
+        x = x + gate_msa.unsqueeze(1) * self.attn(
+            modulate(self.norm1, x, shift_msa, scale_msa, self.enable_modulate_kernel)
+        )
+        x = x + gate_mlp.unsqueeze(1) * self.mlp(
+            modulate(self.norm2, x, shift_mlp, scale_mlp, self.enable_modulate_kernel)
+        )
         return x
 
 

@@ -347,17 +352,18 @@ def __init__(
         class_dropout_prob=0.1,
         num_classes=1000,
         learn_sigma=True,
-        flash_attn=FLASH_ATTN,
-        sequence_parallel=ULYSSES,
-        layernorm_kernel=LAYERNORM_KERNEL,
-        modulate_kernel=MODULATE_KERNEL,
+        enable_flashattn=False,
+        enable_layernorm_kernel=False,
+        enable_modulate_kernel=False,
+        sequence_parallel_size=1,
     ):
         super().__init__()
         self.learn_sigma = learn_sigma
         self.in_channels = in_channels
         self.out_channels = in_channels * 2 if learn_sigma else in_channels
         self.patch_size = patch_size
         self.num_heads = num_heads
+        self.sequence_parallel_size = sequence_parallel_size
 
         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
         self.t_embedder = TimestepEmbedder(hidden_size)

@@ -378,10 +384,10 @@ def __init__(
                 hidden_size,
                 num_heads,
                 mlp_ratio=mlp_ratio,
-                flash_attn=flash_attn,
-                sequence_parallel=sequence_parallel,
-                modulate_kernel=modulate_kernel,
-                layernorm_kernel=layernorm_kernel,
+                enable_flashattn=enable_flashattn,
+                sequence_parallel_size=sequence_parallel_size,
+                enable_modulate_kernel=enable_modulate_kernel,
+                enable_layernorm_kernel=enable_layernorm_kernel,
             )
             for _ in range(depth)
         ]

@@ -471,16 +477,16 @@ def forward(self, x, t, y):
         c = t + y  # (N, D)
 
         # Chunk x on sequence dimension to sp group
-        if ULYSSES:
-            x = x.chunk(SP_SIZE, dim=1)[dist.get_rank()]
+        if self.sequence_parallel_size > 1:
+            x = x.chunk(self.sequence_parallel_size, dim=1)[dist.get_rank()]
 
         for block in self.blocks:
            if self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(self.create_custom_forward(block), x, c)
            else:
                x = block(x, c)  # (N, T, D)
 
-        if ULYSSES:
+        if self.sequence_parallel_size > 1:
            x = gather_forward_split_backward(x, dim=1, process_group=None)
 
         x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
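
The net effect of the dit.py changes is that the module-level ULYSSES / FLASH_ATTN / SP_SIZE / LAYERNORM_KERNEL / MODULATE_KERNEL globals become per-instance constructor arguments. A minimal usage sketch, assuming the DiT_S_2 factory (imported by the new test below) forwards keyword arguments to the DiT constructor as in the upstream DiT code; the specific values are illustrative, not taken from the commit:

```
# Illustrative only: acceleration options are now passed per model instance
# instead of being read from module-level globals.
from opendit.models.dit import DiT_S_2

model = DiT_S_2(
    enable_flashattn=True,          # route attention through flash_attn if installed
    enable_layernorm_kernel=False,  # apex fused LayerNorm
    enable_modulate_kernel=False,   # triton fused modulate kernel
    sequence_parallel_size=1,       # >1 shards the sequence across ranks (only 1 or 2 supported here)
)
```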

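For intuition, here is a plain-torch shape sketch of the sequence-parallel bookkeeping in the refactored forward pass. No process group is involved; `sp_size=2` mirrors the only size beyond 1 that the commit's assertion allows, and `rank` would come from `dist.get_rank()` in the real code:

```
import torch

B, T, D = 2, 8, 16        # toy batch size, sequence length, hidden size
sp_size, rank = 2, 0      # sequence-parallel degree and this rank's index

x = torch.randn(B, T, D)

# forward(): each rank keeps its contiguous slice of the sequence
x_local = x.chunk(sp_size, dim=1)[rank]
assert x_local.shape == (B, T // sp_size, D)

# ... DiT blocks run on the local shard; DistAttention uses all_to_all so each
# rank attends over the full sequence with num_heads // sp_size heads ...

# gather_forward_split_backward(): shards are concatenated back along dim=1
x_full = torch.cat([x.chunk(sp_size, dim=1)[r] for r in range(sp_size)], dim=1)
assert x_full.shape == (B, T, D)
```
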
requirements.txt

Lines changed: 8 additions & 2 deletions
@@ -18,5 +18,11 @@ pytorch_lightning
 h5py
 gdown
 scikit-video
-flash_attn
-pyav
+pyav
+tensorboard
+timm
+matplotlib
+accelerate
+diffusers
+transformers
+flash_attn==2.0.5

tests/test_checkpoint.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+import os
+import shutil
+
+import colossalai
+import torch
+import torch.distributed as dist
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import (
+    check_state_dict_equal,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from colossalai.zero import LowLevelZeroOptimizer
+
+from opendit.models.dit import DiT_S_2
+
+
+# stage 1 and 2 process the optimizer/mode the same way
+# only test 2 is fine
+@clear_cache_before_run()
+@parameterize("stage", [2])
+@parameterize("shard", [True, False])
+@parameterize("offload", [False, True])
+def _test_zero_checkpoint(stage: int, shard: bool, offload: bool):
+    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
+    booster = Booster(plugin=plugin)
+    model = DiT_S_2()
+    criterion = lambda x: x.mean()
+    optimizer = HybridAdam((model.parameters()), lr=0.001)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    x = torch.randn(2, 4, 32, 32).cuda().requires_grad_(True)
+    y = torch.randint(0, 10, (2,)).cuda()
+    t = torch.randint(0, 10, (2,)).cuda()
+    output = model(x, y, t)
+    loss = criterion(output)
+    booster.backward(loss, optimizer)
+    optimizer.step()
+
+    tempdir = "./tempdir"
+    if dist.get_rank() == 0:
+        if os.path.exists(tempdir):
+            shutil.rmtree(tempdir)
+        os.makedirs(tempdir)
+    dist.barrier()
+
+    model_ckpt_path = f"{tempdir}/model"
+    optimizer_ckpt_path = f"{tempdir}/optimizer"
+    # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
+    booster.save_model(model, model_ckpt_path, shard=shard)
+    booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+
+    dist.barrier()
+
+    new_model = DiT_S_2()
+    new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
+    new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
+
+    booster.load_model(new_model, model_ckpt_path)
+    check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+    # check master weight
+    assert isinstance(new_optimizer, LowLevelZeroOptimizer)
+    working_param_id_set = set(id(p) for p in new_model.parameters())
+    for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+        assert p_id in working_param_id_set
+        working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
+        padding = new_optimizer._param_store.get_param_padding_size(working_param)
+        padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
+        working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
+        assert torch.equal(
+            working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
+        )
+
+    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+    check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
+    torch.cuda.empty_cache()
+
+    if dist.get_rank() == 0:
+        shutil.rmtree(tempdir)
+    dist.barrier()
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
+    _test_zero_checkpoint()
+    torch.cuda.empty_cache()
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_zero_checkpoint():
+    spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_zero_checkpoint()
