
Commit 1639171

Rename vae folder and update multinode training in readme (#56)
* rename vqvae
* multi node training
1 parent ff2210e commit 1639171

File tree

10 files changed: +24 −5 lines changed


README.md

Lines changed: 18 additions & 0 deletions
@@ -97,6 +97,7 @@ torchrun --standalone --nproc_per_node=2 train.py \
 ```
 
 We disable all speedup methods by default. Here are details of some key arguments for training:
+- `--nproc_per_node`: The number of GPUs to use on the current node.
 - `--plugin`: The booster plugin used by ColossalAI; `zero2` and `ddp` are supported. The default value is `zero2`. We recommend enabling `zero2`.
 - `--mixed_precision`: The data type for mixed-precision training. The default value is `fp16`.
 - `--grad_checkpoint`: Whether to enable gradient checkpointing, which saves memory during training. The default value is `False`. We recommend disabling it when memory is sufficient.
@@ -107,6 +108,23 @@ We disable all speedup methods by default. Here are details of some key argument
 
 For more details on the configuration of the training process, please visit our code.
 
+<b>Multi-Node Training.</b>
+
+To train OpenDiT on multiple nodes, you can use the following command:
+
+```
+colossalai run --nproc_per_node 8 --hostfile hostfile train.py \
+    --model DiT-XL/2 \
+    --batch_size 2
+```
+
+You also need to create a `hostfile` in the current directory. It should contain the IP addresses of all your nodes, and every node must be reachable over SSH without a password. An example hostfile:
+
+```
+111.111.111.111 # ip of node1
+222.222.222.222 # ip of node2
+```
+
 <b>Inference.</b> You can perform inference using the DiT model as follows. You need to replace the checkpoint path with your own trained model, or you can download the [official](https://github.com/facebookresearch/DiT?tab=readme-ov-file#sampling--) or [our](https://drive.google.com/file/d/1P4t2V3RDNcoCiEkbVWAjNetm3KC_4ueI/view?usp=drive_link) checkpoint for inference.
 
 ```shell
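The multi-node instructions added above require passwordless SSH from the launch node to every host in `hostfile`. Below is a minimal sketch of one common way to arrange that with standard OpenSSH tooling; the account name `user` is a placeholder, the IPs mirror the example hostfile, and none of this is part of the OpenDiT codebase:

```shell
# Generate a key pair on the launch node if one does not exist yet.
ssh-keygen -t ed25519 -N "" -f ~/.ssh/id_ed25519

# Copy the public key to every node listed in hostfile
# ("user" is a placeholder account name).
ssh-copy-id user@111.111.111.111   # node1
ssh-copy-id user@222.222.222.222   # node2

# Verify that each node is reachable without a password prompt.
ssh user@111.111.111.111 hostname
```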
Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 from torchvision.io import write_video
 from torchvision.utils import save_image
 
-from opendit.vqvae.wrapper import AutoencoderKLWrapper
+from opendit.vae.wrapper import AutoencoderKLWrapper
 
 
 def t2v(x):

sample.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@
 from opendit.models.diffusion import create_diffusion
 from opendit.models.dit import DiT_models
 from opendit.utils.download import find_model
-from opendit.vqvae.reconstruct import save_sample
-from opendit.vqvae.wrapper import AutoencoderKLWrapper
+from opendit.vae.reconstruct import save_sample
+from opendit.vae.wrapper import AutoencoderKLWrapper
 
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True

train.py

Lines changed: 3 additions & 2 deletions
@@ -32,7 +32,7 @@
 from opendit.utils.pg_utils import ProcessGroupManager
 from opendit.utils.train_utils import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, update_ema
 from opendit.utils.video_utils import DatasetFromCSV, get_transforms_image, get_transforms_video
-from opendit.vqvae.wrapper import AutoencoderKLWrapper
+from opendit.vae.wrapper import AutoencoderKLWrapper
 
 # the first flag below was False when we tested this script but True makes A100 training a lot faster:
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -113,8 +113,9 @@ def main(args):
     assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
     if args.use_video:
         # Wrap the VAE in a wrapper that handles video data
-        # Use 3d patch size that is divisible by the input size
+        # We use the 2D VAE from Stability AI instead of the 3D VQ-VAE from VideoGPT because it gives better results
         vae = AutoencoderKLWrapper(vae)
+        # Use 3d patch size that is divisible by the input size
         input_size = (args.num_frames, args.image_size, args.image_size)
         for i in range(3):
             assert input_size[i] % vae.patch_size[i] == 0, "Input size must be divisible by patch size"

0 commit comments
