
Commit be9d99a

a-r-r-o-w and sayakpaul authored
DeepSpeed and DDP Configs (#10)
* add configs
* remove compiled ddp config
* add coauthor Co-Authored-By: Sayak Paul <[email protected]>
* update
* deepspeed numbers and fixes

--------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 795f2c2 commit be9d99a


6 files changed: +97 −13 lines


README.md

Lines changed: 45 additions & 1 deletion
@@ -85,7 +85,7 @@ TODO: Add a section on creating and using precomputed embeddings.
 
 We provide training script for both text-to-video and image-to-video generation which are compatible with the [Cog family of models](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce).
 
-Take a look at `training/*.sh`
+Take a look at `*.sh`
 
 Note: Untested on MPS
@@ -282,11 +282,55 @@ ValueError: Expected a cuda device, but got: cpu
 
 </details>
 
+<details>
+<summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
+
+> [!NOTE]
+> Results are for `lora_rank=256` with `gradient_checkpointing` enabled, 2x RTX 4090.
+
+With `train_batch_size = 1`:
+
+| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
+| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |
+
+With `train_batch_size = 4`:
+
+| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
+| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |
+
+</details>
+
 ### Full finetuning
 
 > [!NOTE]
 > `memory_after_validation` is indicative of the peak memory required for training. This is because apart from the activations, parameters and gradients stored for training, you also need to load the vae and text encoder in memory and spend some memory to perform inference. In order to reduce total memory required to perform training, one can choose to not perform validation/testing as part of the training script.
 
+<details>
+<summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
+
+> [!NOTE]
+> Results with `gradient_checkpointing` enabled, 2x RTX 4090.
+
+With `train_batch_size = 1`:
+
+| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
+| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |
+
+With `train_batch_size = 4`:
+
+| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
+| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |
+
+</details>
+
 - [ ] Make scripts compatible with DDP
 - [ ] Make scripts compatible with FSDP
 - [x] Make scripts compatible with DeepSpeed
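The DeepSpeed numbers above were collected on 2x RTX 4090 using the accelerate config added in this commit. As a rough sketch of how such a run would be launched — note that the trainer flags after the script path are illustrative placeholders and may not match the scripts' actual argument names; see the `*.sh` launchers for the real invocation:

# Sketch only: --config_file and the script path come from this commit;
# --gradient_checkpointing appears in the launcher diff below, and
# --train_batch_size is an illustrative placeholder.
accelerate launch --config_file accelerate_configs/deepspeed.yaml \
  training/cogvideox_text_to_video_lora.py \
  --gradient_checkpointing \
  --train_batch_size 1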

accelerate_configs/deepspeed.yaml

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  gradient_clipping: 1.0
+  offload_optimizer_device: cpu
+  offload_param_device: cpu
+  zero3_init_flag: false
+  zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
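This config runs two processes on a single machine with bf16 mixed precision and DeepSpeed ZeRO stage 2, offloading optimizer state and parameters to CPU — presumably the setup behind the "AdamW + CPU/Parameter offloading" tables in the README. It is selected at launch time via `--config_file`; a minimal usage sketch:

# Minimal sketch; the trailing "..." stands for the training script's own arguments.
accelerate launch --config_file accelerate_configs/deepspeed.yaml \
  training/cogvideox_text_to_video_sft.py ...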
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+gpu_ids: 0,1
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
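This second config (its filename is not visible in this view) is a plain two-GPU DDP setup: two processes pinned to GPUs 0 and 1, bf16 mixed precision, and no DeepSpeed or CPU offloading. Usage mirrors the DeepSpeed case; the config path below is a hypothetical placeholder for wherever the file was saved:

# 'accelerate_configs/ddp.yaml' is a hypothetical placeholder path, not the real filename.
accelerate launch --config_file accelerate_configs/ddp.yaml \
  training/cogvideox_text_to_video_sft.py ...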

train_text_to_video_sft.sh

Lines changed: 2 additions & 2 deletions
@@ -55,7 +55,7 @@ for learning_rate in "${LEARNING_RATES[@]}"; do
     --gradient_checkpointing \
     --learning_rate $learning_rate \
     --lr_scheduler $lr_schedule \
-    --lr_warmup_steps 200 \
+    --lr_warmup_steps 800 \
     --lr_num_cycles 1 \
     --enable_slicing \
     --enable_tiling \
@@ -65,7 +65,7 @@ for learning_rate in "${LEARNING_RATES[@]}"; do
     --weight_decay 0.001 \
    --max_grad_norm 1.0 \
     --allow_tf32 \
-    --report_to wandb
+    --report_to wandb \
     --nccl_timeout 1800"
 
     echo "Running command: $cmd"

training/cogvideox_text_to_video_lora.py

Lines changed: 5 additions & 5 deletions
@@ -26,7 +26,7 @@
 import torch
 import transformers
 import wandb
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import (
     DistributedDataParallelKwargs,
@@ -315,7 +315,7 @@ def main(args):
 "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
 and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
 ):
-    weight_dtype = torch.float16
+    weight_dtype = torch.bfloat16
 else:
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -631,7 +631,7 @@ def collate_fn(data):
 
 videos = latent_dist.sample() * VAE_SCALING_FACTOR
 videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
-videos = videos.to(memory_format=torch.contiguous_format).float()
+videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
 model_input = videos
 
 # Encode prompts
@@ -646,7 +646,7 @@ def collate_fn(data):
 requires_grad=False,
 )
 else:
-    prompt_embeds = prompts
+    prompt_embeds = prompts.to(dtype=weight_dtype)
 
 # Sample noise that will be added to the latents
 noise = torch.randn_like(model_input)
@@ -721,7 +721,7 @@ def collate_fn(data):
 progress_bar.update(1)
 global_step += 1
 
-if accelerator.is_main_process:
+if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
     if global_step % args.checkpointing_steps == 0:
         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
         if args.checkpoints_total_limit is not None:

training/cogvideox_text_to_video_sft.py

Lines changed: 5 additions & 5 deletions
@@ -26,7 +26,7 @@
 import torch
 import transformers
 import wandb
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import (
     DistributedDataParallelKwargs,
@@ -271,7 +271,7 @@ def main(args):
 "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
 and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
 ):
-    weight_dtype = torch.float16
+    weight_dtype = torch.bfloat16
 else:
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -562,7 +562,7 @@ def collate_fn(data):
 
 videos = latent_dist.sample() * VAE_SCALING_FACTOR
 videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
-videos = videos.to(memory_format=torch.contiguous_format).float()
+videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
 model_input = videos
 
 # Encode prompts
@@ -577,7 +577,7 @@ def collate_fn(data):
 requires_grad=False,
 )
 else:
-    prompt_embeds = prompts
+    prompt_embeds = prompts.to(dtype=weight_dtype)
 
 # Sample noise that will be added to the latents
 noise = torch.randn_like(model_input)
@@ -652,7 +652,7 @@ def collate_fn(data):
 progress_bar.update(1)
 global_step += 1
 
-if accelerator.is_main_process:
+if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
     if global_step % args.checkpointing_steps == 0:
         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
         if args.checkpoints_total_limit is not None:
