Commit bab4016

oahzxl and KKZ20 authored
support latte and reorganize dir (#77)
* refactor dir and add latte
* Add copyright and license information to source files
* Add VDiT-XL/2x2x2 video model and Latte-XL/2x2x2 model
* Update model and load checkpoint from sharded state dict
* Remove unused import statement
* polish
* Update class labels for video conditioning
* Add colossalai and LowLevelZeroPlugin imports, update MASTER_PORT
* move dir
* update train
* Update model configuration in sample_video.sh and train_video.sh
* Update OpenDiT README with latest news
* Update module imports and model types
* update sequence parallel test
* update readme and train example
* fix bugs on num_heads
* Disable modulate kernel optimization due to NaN issues
* Update model options in training and sampling scripts

---------

Co-authored-by: KKZ20 <[email protected]>
1 parent 0a3843b · commit bab4016

31 files changed: +1068 −794 lines changed

README.md

Lines changed: 24 additions & 8 deletions
@@ -6,9 +6,14 @@
 <p align="center"><a href="https://github.com/NUS-HPC-AI-Lab/OpenDiT">[Homepage]</a> | <a href="https://discord.gg/yXF4n8Et">[Discord]</a> | <a href="./figure/wechat.jpg">[WeChat]</a> | <a href="https://twitter.com/YangYou1991/status/1762447718105170185">[Twitter]</a> | <a href="https://zhuanlan.zhihu.com/p/684457582">[Zhihu]</a> | <a href="https://mp.weixin.qq.com/s/IBb9vlo8hfYKrj9ztxkhjg">[Media]</a></p>
 </p>
 
+### Latest News 🔥
+
+* [2024/03/01] Support DiT-based Latte for text-to-video generation.
+* [2024/02/27] Officially release OpenDiT: An Easy, Fast and Memory-Efficient System for DiT Training and Inference.
+
 # About
 
-OpenDiT is an open-source project that provides a high-performance implementation of Diffusion Transformer(DiT) powered by Colossal-AI, specifically designed to enhance the efficiency of training and inference for DiT applications, including text-to-video generation and text-to-image generation.
+OpenDiT is an open-source project that provides a high-performance implementation of Diffusion Transformer (DiT) powered by Colossal-AI, specifically designed to enhance the efficiency of training and inference for DiT applications, including text-to-video generation and text-to-image generation.
 
 OpenDiT boasts the performance by the following techniques:

@@ -87,26 +92,30 @@ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation -
 
 ### Image
 
-<b>Training.</b> You can train the DiT model by executing the following command:
+<b>Training.</b> You can train the DiT model on CIFAR10 by executing the following command:
 
 ```shell
 # Use script
 bash train_img.sh
 # Use command line
 torchrun --standalone --nproc_per_node=2 train.py \
     --model DiT-XL/2 \
-    --batch_size 2
+    --batch_size 2 \
+    --num_classes 10
 ```
 
 We disable all speedup methods by default. Here are details of some key arguments for training:
 - `--nproc_per_node`: The GPU number you want to use for the current node.
 - `--plugin`: The booster plugin used by ColossalAI, `zero2` and `ddp` are supported. The default value is `zero2`. Recommend to enable `zero2`.
 - `--mixed_precision`: The data type for mixed precision training. The default value is `fp16`.
 - `--grad_checkpoint`: Whether enable the gradient checkpointing. This saves the memory cost during training process. The default value is `False`. Recommend to disable it when memory is enough.
-- `--enable_modulate_kernel`: Whether enable the modulate kernel optimization. This speeds up the training process. The default value is `False`. Recommend to enable it for GPU < H100.
 - `--enable_layernorm_kernel`: Whether enable the layernorm kernel optimization. This speeds up the training process. The default value is `False`. Recommend to enable it.
 - `--enable_flashattn`: Whether enable the FlashAttention. This speeds up the training process. The default value is `False`. Recommend to enable.
+- `--enable_modulate_kernel`: Whether enable the modulate kernel optimization. This speeds up the training process. The default value is `False`. This kernel can cause NaN under some circumstances, so we recommend disabling it for now.
 - `--sequence_parallel_size`: The sequence parallelism size. Will enable sequence parallelism when setting a value > 1. The default value is 1. Recommend to disable it if memory is enough.
+- `--load`: Load a previously saved checkpoint directory and continue training.
+- `--num_classes`: The number of label classes. Only used for label-to-image generation.
+
 
 For more details on the configuration of the training process, please visit our code.

@@ -137,14 +146,17 @@ python sample.py --model DiT-XL/2 --image_size 256 --ckpt ./model.pt
 ```
 
 ### Video
-<b>Training.</b> Our video training pipeline is a faithful implementation, and we encourage you to explore your own strategies using OpenDiT. You can train the video DiT model by executing the following command:
+<b>Training.</b> We currently support `VDiT` and `Latte` for video generation. VDiT adopts the DiT structure and uses video as its input data. Latte further uses more efficient spatial & temporal blocks on top of VDiT (not exactly aligned with the original [Latte](https://github.com/Vchitect/Latte)).
+
+Our video training pipeline is a faithful implementation, and we encourage you to explore your own strategies using OpenDiT. You can train the video DiT model by executing the following command:
 
 ```shell
 # train with scipt
 bash train_video.sh
 # train with command line
+# model can also be Latte-XL/1x2x2
 torchrun --standalone --nproc_per_node=2 train.py \
-    --model vDiT-XL/222 \
+    --model VDiT-XL/1x2x2 \
     --use_video \
     --data_path ./videos/demo.csv \
     --batch_size 1 \
@@ -166,15 +178,18 @@ This script shares the same speedup methods as we have shown in the image traini
 # Use script
 bash sample_video.sh
 # Use command line
+# model can also be Latte-XL/1x2x2
 python sample.py \
-    --model vDiT-XL/222 \
+    --model VDiT-XL/1x2x2 \
     --use_video \
     --ckpt ckpt_path \
     --num_frames 16 \
     --image_size 256 \
     --frame_interval 3
 ```
 
+Inference tips: 1) The EMA model requires quite a long time to converge and produce meaningful results, so you can sample the base model (`--ckpt /epochXX-global_stepXX/model`) instead of the EMA model (`--ckpt /epochXX-global_stepXX/ema.pt`) to check your training progress. 2) Modifying the text condition in `sample.py` so that it aligns with your dataset helps produce better results in the early stage of training.
+
 ## FastSeq
 
 ![fastseq_overview](./figure/fastseq_overview.png)
@@ -210,7 +225,8 @@ torchrun --standalone --nproc_per_node=8 train.py \
     --batch_size 180 \
     --enable_layernorm_kernel \
     --enable_flashattn \
-    --mixed_precision fp16
+    --mixed_precision fp16 \
+    --num_classes 1000
 ```

File renamed without changes.
File renamed without changes.
File renamed without changes.

opendit/embed/label_emb.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Modified from Meta DiT
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DiT: https://github.com/facebookresearch/DiT/tree/main
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+
+
+import torch
+from torch import nn
+
+
+class LabelEmbedder(nn.Module):
+    """
+    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+    """
+
+    def __init__(self, num_classes, hidden_size, dropout_prob):
+        super().__init__()
+        use_cfg_embedding = dropout_prob > 0
+        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+        self.num_classes = num_classes
+        self.dropout_prob = dropout_prob
+
+    def token_drop(self, labels, force_drop_ids=None):
+        """
+        Drops labels to enable classifier-free guidance.
+        """
+        if force_drop_ids is None:
+            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+        else:
+            drop_ids = force_drop_ids == 1
+        labels = torch.where(drop_ids, self.num_classes, labels)
+        return labels
+
+    def forward(self, labels, train, force_drop_ids=None):
+        use_dropout = self.dropout_prob > 0
+        if (train and use_dropout) or (force_drop_ids is not None):
+            labels = self.token_drop(labels, force_drop_ids)
+        embeddings = self.embedding_table(labels)
+        return embeddings
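For context, a minimal usage sketch of this embedder (not part of the commit; the class count, hidden size, and dropout probability below are arbitrary placeholders):

```python
import torch

from opendit.embed.label_emb import LabelEmbedder

# Hypothetical setup: 10 classes (e.g. CIFAR10), 1152-dim hidden size,
# 10% label dropout for classifier-free guidance.
embedder = LabelEmbedder(num_classes=10, hidden_size=1152, dropout_prob=0.1)

labels = torch.randint(0, 10, (4,))       # a batch of 4 class labels
train_emb = embedder(labels, train=True)  # some labels may be replaced by the CFG index (10)
eval_emb = embedder(labels, train=False)  # no dropout at evaluation time
print(train_emb.shape)                    # torch.Size([4, 1152])
```

Because `dropout_prob > 0`, the embedding table holds `num_classes + 1` rows, and dropped labels are mapped to the extra row so the model can also be run unconditionally for classifier-free guidance.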

opendit/embed/patch_emb.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn.functional as F
+from torch import nn
+
+
+class PatchEmbed3D(nn.Module):
+    """Video to Patch Embedding.
+
+    Args:
+        patch_size (tuple[int]): Patch token size. Default: (2,4,4).
+        in_chans (int): Number of input video channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(
+        self,
+        patch_size=(2, 4, 4),
+        in_chans=3,
+        embed_dim=96,
+        norm_layer=None,
+        flatten=True,
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.flatten = flatten
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        """Forward function."""
+        # padding
+        _, _, D, H, W = x.size()
+        if W % self.patch_size[2] != 0:
+            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
+        if H % self.patch_size[1] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
+        if D % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
+
+        x = self.proj(x)  # (B C T H W)
+        if self.norm is not None:
+            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
+            x = x.flatten(2).transpose(1, 2)
+            x = self.norm(x)
+            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
+        return x
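As a quick illustration of the shapes involved (a minimal sketch, not part of the commit; the tensor sizes and embedding dimension are arbitrary, and mapping the `1x2x2` model-name suffix to a `(1, 2, 2)` patch size is an assumption based on the README changes above):

```python
import torch

from opendit.embed.patch_emb import PatchEmbed3D

# Hypothetical video latent: batch 2, 4 channels, 16 frames, 32x32 spatial grid.
x = torch.randn(2, 4, 16, 32, 32)

# A (1, 2, 2) patch keeps the temporal resolution and halves each spatial side.
patch_embed = PatchEmbed3D(patch_size=(1, 2, 2), in_chans=4, embed_dim=1152)
tokens = patch_embed(x)  # flatten=True, so BCTHW -> BNC
print(tokens.shape)      # torch.Size([2, 4096, 1152]); 4096 = 16 * 16 * 16 patches
```

Inputs whose depth, height, or width are not divisible by the patch size are zero-padded first, so arbitrary frame counts and resolutions are accepted.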

opendit/embed/pos_emb.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+
+
+def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    t_size: int of the temporal size
+    return:
+    pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    assert embed_dim % 4 == 0
+    embed_dim_spatial = embed_dim // 4 * 3
+    embed_dim_temporal = embed_dim // 4
+
+    # spatial
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+    # temporal
+    grid_t = np.arange(t_size, dtype=np.float32)
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+    # concatenate: [T, H, W] order
+    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+    pos_embed_temporal = np.repeat(pos_embed_temporal, grid_size**2, axis=1)  # [T, H*W, D // 4]
+    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+    pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
+    pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]
+
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed(embed_dim, length):
+    pos = np.arange(0, length)[..., None]
+    return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
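For reference, a minimal sketch of how these helpers might be called (the embedding dimension, grid size, and frame count below are arbitrary, not taken from the commit):

```python
from opendit.embed.pos_emb import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed

# Hypothetical sizes: 1152-dim embedding, 16x16 spatial grid, 16 frames.
pos_3d = get_3d_sincos_pos_embed(embed_dim=1152, grid_size=16, t_size=16)
print(pos_3d.shape)  # (4096, 1152): one row per (t, h, w) position

pos_2d = get_2d_sincos_pos_embed(embed_dim=1152, grid_size=16)
print(pos_2d.shape)  # (256, 1152): one row per (h, w) position
```

In the 3D variant, a quarter of the channels encode the temporal index and the remaining three quarters encode the spatial position, which is why `embed_dim` must be divisible by 4.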
