Commit c15d82b

fix sample and update readme (#79)
* fix sample and update readme
* Update num_classes argument in README.md
* Update ckpt argument in sample.py
1 parent cdc4d5f commit c15d82b

File tree: README.md, sample.py, sample_img.sh

3 files changed: +21 / -6 lines
README.md

Lines changed: 11 additions & 3 deletions
@@ -114,7 +114,7 @@ We disable all speedup methods by default. Here are details of some key argument
 - `--enable_modulate_kernel`: Whether enable the modulate kernel optimization. This speeds up the training process. The default value is `False`. This kernel will cause NaN under some circumstances. So we recommend to disable it for now.
 - `--sequence_parallel_size`: The sequence parallelism size. Will enable sequence parallelism when setting a value > 1. The default value is 1. Recommend to disable it if memory is enough.
 - `--load`: Load previous saved checkpoint dir and continue training.
-- `--num_classes`: Label class number. Only used for label-to-image generation.
+- `--num_classes`: Label class number. Should be 10 for CIFAR10 and 1000 for ImageNet. Only used for label-to-image generation.
 
 
 For more details on the configuration of the training process, please visit our code.
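
For reference, here is a minimal sketch of how a label-class count like the `--num_classes` documented above is typically consumed in label-to-image training; it is an illustration only, not OpenDiT's actual code, and every name besides the argument itself is an assumption:

```
# Illustrative sketch, not OpenDiT's training code.
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--num_classes", type=int, default=1000,
                    help="10 for CIFAR10, 1000 for ImageNet")
args = parser.parse_args()

# Conditioning labels must lie in [0, num_classes); anything larger would
# index past the end of the model's label-embedding table.
labels = torch.randint(0, args.num_classes, (8,))
print(labels)
```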
@@ -126,7 +126,8 @@ To train OpenDiT on multiple nodes, you can use the following command:
 ```
 colossalai run --nproc_per_node 8 --hostfile hostfile train.py \
     --model DiT-XL/2 \
-    --batch_size 2
+    --batch_size 2 \
+    --num_classes 10
 ```
 
 And you need to create `hostfile` under the current dir. It should contain all IP address of your nodes and you need to make sure all nodes can be connected without password by ssh. An example of hostfile:
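
The hostfile itself is plain text with one reachable node per line. Since the README's own example sits outside this hunk, here is a hedged stand-in that writes a two-node file with placeholder addresses:

```
# Hypothetical example: one node address per line; all nodes must be
# reachable by passwordless ssh from the launch machine.
with open("hostfile", "w") as f:
    f.write("192.168.0.1\n")
    f.write("192.168.0.2\n")
```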
@@ -142,8 +143,15 @@ And you need to create `hostfile` under the current dir. It should contain all I
 # Use script
 bash sample_img.sh
 # Use command line
-python sample.py --model DiT-XL/2 --image_size 256 --ckpt ./model.pt
+python sample.py \
+    --model DiT-XL/2 \
+    --image_size 256 \
+    --num_classes 10 \
+    --ckpt ckpt_path
 ```
+Here are details of some additional key arguments for inference:
+- `--ckpt`: The weight of ema model `ema.pt`. To check your training progress, it can also be our saved base model `epochXX-global_stepXX/model`, which produces better results than ema in the early training stage.
+- `--num_classes`: Label class number. Should be 10 for CIFAR10, and 1000 for ImageNet (including official and our checkpoint).
 
 ### Video
 <b>Training.</b> We current support `VDiT` and `Latte` for video generation. VDiT adopts DiT structure and use video as inputs data. Latte further use more efficient spatial & temporal blocks based on VDiT (not exactly align with origin [Latte](https://github.com/Vchitect/Latte)).
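
A hedged sketch of what the new `--ckpt` note refers to, assuming `ema.pt` stores a plain PyTorch state dict; the path and the commented-out `build_model` helper are placeholders, not OpenDiT APIs:

```
# Hedged sketch, not OpenDiT's loader.
import torch

ckpt_path = "ema.pt"  # or the saved base model "epochXX-global_stepXX/model" dir,
                      # which goes through the training framework's own loader

state_dict = torch.load(ckpt_path, map_location="cpu")
# model = build_model(num_classes=10)  # hypothetical helper matching --model / --num_classes
# model.load_state_dict(state_dict)
# model.eval()
```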

sample.py

Lines changed: 5 additions & 2 deletions
@@ -92,11 +92,14 @@ def main(args):
         y = class_labels * 2
     else:
         # Labels to condition the model with (feel free to change):
-        class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
+        if args.num_classes == 1000:
+            class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
+        else:
+            class_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
     n = len(class_labels)
     z = torch.randn(n, 4, input_size, input_size, device=device)
     y = torch.tensor(class_labels, device=device)
-    y_null = torch.tensor([1000] * n, device=device)
+    y_null = torch.tensor([0] * n, device=device)
     y = torch.cat([y, y_null], 0)
 
     # Setup classifier-free guidance:
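
For readers skimming the diff above, here is a standalone sketch of the classifier-free-guidance batch it builds; the concrete null-label value depends on which index the model's label embedder reserves for the unconditional slot, so the numbers below are illustrative:

```
# Standalone sketch of the conditional/unconditional batch used for CFG.
import torch

class_labels = [0, 1, 2, 3]             # conditional labels, each < num_classes
null_label = 0                          # index the embedder treats as "no label" (illustrative)
n = len(class_labels)

z = torch.randn(n, 4, 32, 32)           # latent noise
y = torch.tensor(class_labels)
y_null = torch.tensor([null_label] * n)

z = torch.cat([z, z], 0)                # duplicate noise for both halves
y = torch.cat([y, y_null], 0)           # first half conditional, second half unconditional
# After the model runs on (z, y), the halves recombine as
# eps_uncond + cfg_scale * (eps_cond - eps_uncond).
```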

sample_img.sh

Lines changed: 5 additions & 1 deletion
@@ -1 +1,5 @@
-python sample.py --model DiT-XL/2 --image_size 256 --ckpt ./pretrained/DiT-XL-2-256x256.pt
+python sample.py \
+    --model DiT-XL/2 \
+    --image_size 256 \
+    --num_classes 10 \
+    --ckpt ckpt_path
