
Commit 3a519f5

Full finetuning memory requirements (#9)
* update
* update
* model cpu offloading
1 parent be9d99a commit 3a519f5

File tree: 4 files changed (+78, -2 lines)


README.md

Lines changed: 64 additions & 2 deletions
@@ -108,6 +108,8 @@ Supported and verified memory optimizations for training include:
 
 > [!IMPORTANT]
 > The memory requirements are reported after running the `training/prepare_dataset.py`, which converts the videos and captions to latents and embeddings. During training, we directly load the latents and embeddings, and do not require the VAE or the T5 text encoder. However, if you perform validation/testing, these must be loaded and increase the amount of required memory. Not performing validation/testing saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
+>
+> If you choose to run validation/testing, you can save some memory on lower VRAM GPUs by specifying `--enable_model_cpu_offloading`.
 
 ### LoRA finetuning
 
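As a rough illustration of what model-wise CPU offloading means for the validation pipeline, here is a minimal sketch using diffusers' `CogVideoXPipeline`; it is not code from this commit, and the prompt, dtype, and output path are placeholders.

```python
# Minimal sketch of what `--enable_model_cpu_offloading` turns on during validation/testing.
# Assumes diffusers>=0.30 and accelerate are installed; all settings below are illustrative.
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)

# Instead of pipe.to("cuda"), keep the text encoder, transformer and VAE in CPU RAM and move
# each component to the GPU only while it runs. Peak VRAM drops at the cost of some speed.
pipe.enable_model_cpu_offload()

video = pipe(prompt="A panda strumming a guitar in a bamboo grove", num_frames=49).frames[0]
export_to_video(video, "validation_sample.mp4", fps=8)
```
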
@@ -307,7 +309,64 @@ With `train_batch_size = 4`:
 ### Full finetuning
 
 > [!NOTE]
-> `memory_after_validation` is indicative of the peak memory required for training. This is because apart from the activations, parameters and gradients stored for training, you also need to load the vae and text encoder in memory and spend some memory to perform inference. In order to reduce total memory required to perform training, one can choose to not perform validation/testing as part of the training script.
+> Trying to run full finetuning without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
+
+<details>
+<summary> AdamW </summary>
+
+With `train_batch_size = 1`:
+
+| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
+| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
+
+With `train_batch_size = 4`:
+
+| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
+| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
+
+</details>
+
+<details>
+<summary> AdamW (8-bit bitsandbytes) </summary>
+
+With `train_batch_size = 1`:
+
+| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
+| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |
+
+With `train_batch_size = 4`:
+
+| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
+| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |
+
+</details>
+
+<details>
+<summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
+
+With `train_batch_size = 1`:
+
+| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
+| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |
+
+With `train_batch_size = 4`:
+
+| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
+|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
+| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
+| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |
+
+</details>
 
 <details>
 <summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
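
The tables above cover three optimizer setups. As a non-authoritative sketch of how such optimizers are typically constructed (the `Linear` module stands in for the CogVideoX transformer, and the hyperparameters are placeholders, not values from this repository):

```python
# Illustrative only: the optimizer variants benchmarked above, applied to a stand-in module.
import torch

model = torch.nn.Linear(128, 128).cuda()   # stand-in for the CogVideoX transformer
params = list(model.parameters())

# 1) Plain AdamW: both optimizer moments live on the GPU in fp32.
optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-4)

# 2) 8-bit AdamW from bitsandbytes: moments are quantized to 8 bits, shrinking optimizer memory.
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(params, lr=1e-4, weight_decay=1e-4)

# 3) CPUOffloadOptimizer from torchao: optimizer state (and, optionally, gradients) is kept in
#    CPU RAM, so GPU memory mostly holds weights and activations.
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
optimizer = CPUOffloadOptimizer(
    params,
    torch.optim.AdamW,       # base optimizer to run on the offloaded state
    offload_gradients=True,  # the "(with gradient offloading)" configuration in the table above
    lr=1e-4,
    weight_decay=1e-4,
)
```
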
@@ -331,7 +390,10 @@ With `train_batch_size = 4`:
 
 </details>
 
-- [ ] Make scripts compatible with DDP
+> [!NOTE]
+> `memory_after_validation` is indicative of the peak memory required for training. This is because apart from the activations, parameters and gradients stored for training, you also need to load the vae and text encoder in memory and spend some memory to perform inference. In order to reduce total memory required to perform training, one can choose to not perform validation/testing as part of the training script.
+
+- [x] Make scripts compatible with DDP
 - [ ] Make scripts compatible with FSDP
 - [x] Make scripts compatible with DeepSpeed
 - [x] Test scripts with memory-efficient optimizer from bitsandbytes
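
Every run in the tables above uses `gradient_checkpointing = True`, and the NOTE points out that full finetuning without it OOMs even on an 80 GB A100. Here is a hedged sketch of how gradient checkpointing is usually enabled on the diffusers transformer (model id and dtype are illustrative, not taken from this commit):

```python
# Sketch: recompute activations during the backward pass instead of storing them all,
# trading extra compute for a much lower activation-memory peak.
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float32
)
transformer.train()
transformer.enable_gradient_checkpointing()
```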

training/args.py

Lines changed: 6 additions & 0 deletions
@@ -140,6 +140,12 @@ def _get_validation_args(parser: argparse.ArgumentParser) -> None:
         default=False,
         help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
     )
+    parser.add_argument(
+        "--enable_model_cpu_offloading",
+        action="store_true",
+        default=False,
+        help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
+    )
 
 
 def _get_training_args(parser: argparse.ArgumentParser) -> None:
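
A small, self-contained sketch of how the new flag flows from the command line into the validation path; the parser wiring mirrors the hunk above, and the guarded call mirrors the script changes below (the `parse_args` input is illustrative):

```python
# Illustrative: parse the new flag and gate model CPU offloading on it.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_model_cpu_offloading",
    action="store_true",
    default=False,
    help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
)

args = parser.parse_args(["--enable_model_cpu_offloading"])  # as passed by a launch command

if args.enable_model_cpu_offloading:
    # pipe.enable_model_cpu_offload()  # called on the validation pipeline in the scripts below
    print("Model CPU offloading will be enabled for validation/testing.")
```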

training/cogvideox_text_to_video_lora.py

Lines changed: 4 additions & 0 deletions
@@ -779,6 +779,8 @@ def collate_fn(data):
             pipe.vae.enable_slicing()
         if args.enable_tiling:
             pipe.vae.enable_tiling()
+        if args.enable_model_cpu_offloading:
+            pipe.enable_model_cpu_offload()
 
         validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
         for validation_prompt in validation_prompts:
@@ -853,6 +855,8 @@ def collate_fn(data):
             pipe.vae.enable_slicing()
         if args.enable_tiling:
             pipe.vae.enable_tiling()
+        if args.enable_model_cpu_offloading:
+            pipe.enable_model_cpu_offload()
 
         # Load LoRA weights
         lora_scaling = args.lora_alpha / args.rank

training/cogvideox_text_to_video_sft.py

Lines changed: 4 additions & 0 deletions
@@ -710,6 +710,8 @@ def collate_fn(data):
             pipe.vae.enable_slicing()
         if args.enable_tiling:
             pipe.vae.enable_tiling()
+        if args.enable_model_cpu_offloading:
+            pipe.enable_model_cpu_offload()
 
         validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
         for validation_prompt in validation_prompts:
@@ -785,6 +787,8 @@ def collate_fn(data):
             pipe.vae.enable_slicing()
         if args.enable_tiling:
             pipe.vae.enable_tiling()
+        if args.enable_model_cpu_offloading:
+            pipe.enable_model_cpu_offload()
 
         # Run inference
         validation_outputs = []
