
Commit e5df80c

ArEnSc, sayakpaul, and a-r-r-o-w authored

Full Finetuning for LTX, possibly extended to other models. (#192)

* Full Finetuning for LTX, possibly extended to other models.
* Change name of the flag
* Used disable grad for component when LoRA fine-tuning is enabled
* Suggestions addressed: renamed to SFT; added 2 other models. Testing required.
* Switching to Full FineTuning
* Run linter.
* Parse subfolder when needed.
* Tackle saving and loading hooks.
* Tackle validation.
* Fix subfolder bug.
* Remove __class__.
* Refactor.
* Remove unnecessary changes.
* Handle saving of final model weights correctly.
* LTX uses a default frame rate of 24 FPS, so the validation output frame rate must match that value: add frame-rate args and update video output and inference frame rate.
* There was a result_args mapping that needed to be modified.
* Update README.md.
* Update docs.
* Add training configuration in cogvideox.

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Aryan <[email protected]>
1 parent a389807 commit e5df80c

File tree

17 files changed, +365 -131 lines changed

README.md

Lines changed: 9 additions & 9 deletions

````diff
@@ -12,7 +12,8 @@ FineTrainers is a work-in-progress library to support (accessible) training of v
 
 ## News
 
-- 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
+- 🔥 **2024-01-13**: Support for T2V full-finetuning added! Thanks to @ArEnSc for taking up the initiative!
+- 🔥 **2024-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
 - 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
 - 🔥 **2024-12-18**: Support for T2V LoRA finetuning of [LTX Video](https://huggingface.co/docs/diffusers/main/api/pipelines/ltx_video) added!
@@ -137,17 +138,16 @@ For inference, refer [here](./docs/training/ltx_video.md#inference). For docs re
 
 <div align="center">
 
-| **Model Name** | **Tasks** | **Min. GPU VRAM** |
-|:---:|:---:|:---:|
-| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 11 GB |
-| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 42 GB |
-| [CogVideoX](./docs/training/cogvideox.md) | Text-to-Video | 12GB<sup>*</sup> |
+| **Model Name** | **Tasks** | **Min. LoRA VRAM<sup>*</sup>** | **Min. Full Finetuning VRAM<sup>^</sup>** |
+|:------------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:|
+| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 11 GB | 21 GB |
+| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 42 GB | OOM |
+| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 21 GB | 53 GB |
 
 </div>
 
-<sub><sup>*</sup>Noted for the 5B variant.</sub>
-
-Note that the memory consumption in the table is reported with most of the options, discussed in [docs/training/optimizations](./docs/training/optimization.md), enabled.
+<sub><sup>*</sup>Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using fp8 weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).</sub><br/>
+<sub><sup>^</sup>Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using bf16 weights & gradient checkpointing.</sub>
 
 If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md).
````
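The full-finetuning column can be sanity-checked with simple parameter arithmetic. A back-of-the-envelope sketch for LTX-Video follows (assumptions not stated in this commit: bf16 weights and gradients, fp32 AdamW moment buffers, activation memory ignored under gradient checkpointing; the parameter count comes from the LTX training configuration later in this diff):

```python
# Back-of-the-envelope estimate for LTX-Video full finetuning (illustrative).
GiB = 1024**3
params = 1_923_385_472             # trainable parameters, from the LTX docs below

weights_gib = params * 2 / GiB     # bf16 weights   -> ~3.58 GiB (matches "before training start" in the LTX table)
grads_gib = params * 2 / GiB       # bf16 gradients -> ~3.58 GiB
adamw_gib = params * 2 * 4 / GiB   # two fp32 AdamW moment buffers -> ~14.33 GiB

print(f"{weights_gib + grads_gib + adamw_gib:.1f} GiB")  # ~21.5 GiB, in line with the 21 GB table entry
```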

docs/training/README.md

Lines changed: 14 additions & 8 deletions

````diff
@@ -1,8 +1,9 @@
-This directory contains the training-related specifications for all the models we support in `finetrainers`. Each model page has:
+# FineTrainers training documentation
 
-* an example training command
-* inference example
-* numbers on memory consumption
+This directory contains the training-related specifications for all the models we support in `finetrainers`. Each model page has:
+- an example training command
+- inference example
+- numbers on memory consumption
 
 By default, we don't include any validation-related arguments in the example training commands. To enable validation inference, one can pass:
 
@@ -12,8 +13,13 @@ By default, we don't include any validation-related arguments in the example tra
 + --validation_steps 100
 ```
 
-## Model-specific docs
+Supported models:
+- [CogVideoX](./cogvideox.md)
+- [LTX-Video](./ltx_video.md)
+- [HunyuanVideo](./hunyuan_video.md)
+
+Supported training types:
+- LoRA (`--training_type lora`)
+- Full finetuning (`--training_type full-finetune`)
 
-* [CogVideoX](./cogvideox.md)
-* [LTX-Video](./ltx_video.md)
-* [HunyuanVideo](./hunyuan_video.md)
+Arguments for training are well-documented in the code. For more information, please run `python train.py --help`.
````
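To make the new training-type flag concrete, here is a minimal, self-contained argparse sketch mirroring the `--training_type` change made in `finetrainers/args.py` later in this commit (illustrative, not the library's parser):

```python
import argparse

# Minimal stand-in for finetrainers' parser, mirroring the diff below.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--training_type",
    type=str,
    choices=["lora", "full-finetune"],  # new in this commit
    required=True,
    help="Type of training to perform. Choose between ['lora', 'full-finetune']",
)

args = parser.parse_args(["--training_type", "full-finetune"])
print(args.training_type)  # -> full-finetune
# Any other value, e.g. "sft", exits with an "invalid choice" error thanks to `choices`.
```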

docs/training/cogvideox.md

Lines changed: 29 additions & 0 deletions

````diff
@@ -2,6 +2,8 @@
 
 ## Training
 
+For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
+
 ```bash
 #!/bin/bash
 export WANDB_MODE="offline"
@@ -84,6 +86,8 @@ echo -ne "-------------------- Finished executing script --------------------\n\
 
 ## Memory Usage
 
+### LoRA
+
 LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x480x720` resolutions, **with precomputation**:
 
 ```
@@ -109,6 +113,31 @@ Training configuration: {
 | after validation end | 11.145 | 28.324 |
 | after training end | 11.144 | 11.592 |
 
+### Full finetuning
+
+```
+Training configuration: {
+    "trainable parameters": 5570283072,
+    "total samples": 1,
+    "train epochs": 2,
+    "train steps": 2,
+    "batches per device": 1,
+    "total batches observed per epoch": 1,
+    "train batch size": 1,
+    "gradient accumulation steps": 1
+}
+```
+
+| stage | memory_allocated | max_memory_reserved |
+|:-----------------------------:|:-----------------:|:-------------------:|
+| after precomputing conditions | 8.880 | 8.941 |
+| after precomputing latents | 9.300 | 12.441 |
+| before training start | 10.376 | 10.387 |
+| after epoch 1 | 31.160 | 52.939 |
+| before validation start | 31.161 | 52.939 |
+| after validation end | 31.161 | 52.939 |
+| after training end | 31.160 | 34.295 |
+
 ## Supported checkpoints
 
 CogVideoX has multiple checkpoints as one can note [here](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). The following checkpoints were tested with `finetrainers` and are known to be working:
````

docs/training/hunyuan_video.md

Lines changed: 8 additions & 0 deletions

````diff
@@ -2,6 +2,8 @@
 
 ## Training
 
+For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
+
 ```bash
 #!/bin/bash
 
@@ -87,6 +89,8 @@ echo -ne "-------------------- Finished executing script --------------------\n\
 
 ## Memory Usage
 
+### LoRA
+
 LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**:
 
 ```
@@ -139,6 +143,10 @@ Training configuration: {
 
 Note: requires about `47` GB of VRAM with validation. If validation is not performed, the memory usage is reduced to about `42` GB.
 
+### Full finetuning
+
+Currently, full finetuning is not supported for HunyuanVideo. It goes out of memory (OOM) for `49x512x768` resolutions.
+
 ## Inference
 
 Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
````

docs/training/ltx_video.md

Lines changed: 28 additions & 1 deletion

````diff
@@ -2,7 +2,7 @@
 
 ## Training
 
-Provided you have a dataset:
+For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
 
 ```bash
 #!/bin/bash
@@ -88,6 +88,8 @@ echo -ne "-------------------- Finished executing script --------------------\n\
 
 ## Memory Usage
 
+### LoRA
+
 LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:
 
 ```
@@ -140,6 +142,31 @@ Training configuration: {
 
 Note: requires about `17.5` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to `11` GB.
 
+### Full Finetuning
+
+```
+Training configuration: {
+    "trainable parameters": 1923385472,
+    "total samples": 1,
+    "train epochs": 10,
+    "train steps": 10,
+    "batches per device": 1,
+    "total batches observed per epoch": 1,
+    "train batch size": 1,
+    "gradient accumulation steps": 1
+}
+```
+
+| stage | memory_allocated | max_memory_reserved |
+|:-----------------------------:|:----------------:|:-------------------:|
+| after precomputing conditions | 8.89 | 8.937 |
+| after precomputing latents | 9.701 | 11.615 |
+| before training start | 3.583 | 4.025 |
+| after epoch 1 | 10.769 | 20.357 |
+| before validation start | 10.769 | 20.357 |
+| after validation end | 10.769 | 28.332 |
+| after training end | 10.769 | 12.904 |
+
 ## Inference
 
 Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
````
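For reference, a minimal LoRA inference sketch for LTX-Video (assumptions: the `Lightricks/LTX-Video` diffusers checkpoint, a CUDA device, the placeholder LoRA name above, and a made-up prompt; the docs' own snippet may differ):

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("my-awesome-name/my-awesome-lora")
pipe.to("cuda")

video = pipe(
    prompt="A cat walking on wet grass at sunset",
    num_frames=49,
    height=512,
    width=768,
).frames[0]

# LTX-Video generates at 24 FPS by default, which is why this commit adds
# --validation_frame_rate: the exported video should match that rate.
export_to_video(video, "output.mp4", fps=24)
```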

finetrainers/args.py

Lines changed: 28 additions & 2 deletions

````diff
@@ -207,6 +207,8 @@ class Args:
         Perform validation every `n` training steps.
     enable_model_cpu_offload (`bool`, defaults to `False`):
         Whether or not to offload different modeling components to CPU during validation.
+    validation_frame_rate (`int`, defaults to `25`):
+        Frame rate to use for the validation videos. This value defaults to 25, as used in the LTX Video pipeline.
 
     MISCELLANEOUS ARGUMENTS
     -----------------------
@@ -319,6 +321,7 @@ class Args:
     validation_every_n_epochs: Optional[int] = None
     validation_every_n_steps: Optional[int] = None
     enable_model_cpu_offload: bool = False
+    validation_frame_rate: int = 25
 
     # Miscellaneous arguments
     tracker_name: str = "finetrainers"
@@ -417,6 +420,7 @@ def to_dict(self) -> Dict[str, Any]:
             "validation_every_n_epochs": self.validation_every_n_epochs,
             "validation_every_n_steps": self.validation_every_n_steps,
             "enable_model_cpu_offload": self.enable_model_cpu_offload,
+            "validation_frame_rate": self.validation_frame_rate,
         },
         "miscellaneous_arguments": {
             "tracker_name": self.tracker_name,
@@ -460,6 +464,7 @@ def parse_arguments() -> Args:
 
 
 def validate_args(args: Args):
+    _validate_training_args(args)
     _validate_validation_args(args)
 
 
@@ -678,8 +683,9 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--training_type",
         type=str,
+        choices=["lora", "full-finetune"],
         required=True,
-        help="Type of training to perform. Choose between ['lora']",
+        help="Type of training to perform. Choose between ['lora', 'full-finetune']",
     )
     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
     parser.add_argument(
@@ -713,7 +719,11 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
         help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
     )
     parser.add_argument(
-        "--target_modules", type=str, default="to_k to_q to_v to_out.0", nargs="+", help="The target modules for LoRA."
+        "--target_modules",
+        type=str,
+        default=["to_k", "to_q", "to_v", "to_out.0"],
+        nargs="+",
+        help="The target modules for LoRA.",
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -890,6 +900,12 @@ def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
         default=None,
         help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
     )
+    parser.add_argument(
+        "--validation_frame_rate",
+        type=int,
+        default=25,
+        help="Frame rate to use for the validation videos.",
+    )
     parser.add_argument(
         "--enable_model_cpu_offload",
         action="store_true",
@@ -1085,6 +1101,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args.validation_every_n_epochs = args.validation_epochs
     result_args.validation_every_n_steps = args.validation_steps
     result_args.enable_model_cpu_offload = args.enable_model_cpu_offload
+    result_args.validation_frame_rate = args.validation_frame_rate
 
     # Miscellaneous arguments
     result_args.tracker_name = args.tracker_name
@@ -1100,6 +1117,15 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     return result_args
 
 
+def _validate_training_args(args: Args):
+    if args.training_type == "lora":
+        assert args.rank is not None, "Rank is required for LoRA training"
+        assert args.lora_alpha is not None, "LoRA alpha is required for LoRA training"
+        assert (
+            args.target_modules is not None and len(args.target_modules) > 0
+        ), "Target modules are required for LoRA training"
+
+
 def _validate_validation_args(args: Args):
     assert args.validation_prompts is not None, "Validation prompts are required for validation"
     if args.validation_images is not None:
````
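The behavior of the new `_validate_training_args` hook, restated as a self-contained sketch that runs without the library (illustrative, not the library code itself):

```python
# Self-contained restatement of the new LoRA-argument check.
def check_lora_args(training_type, rank, lora_alpha, target_modules):
    if training_type == "lora":
        assert rank is not None, "Rank is required for LoRA training"
        assert lora_alpha is not None, "LoRA alpha is required for LoRA training"
        assert (
            target_modules is not None and len(target_modules) > 0
        ), "Target modules are required for LoRA training"

check_lora_args("lora", 128, 64, ["to_k", "to_q", "to_v", "to_out.0"])  # passes
check_lora_args("full-finetune", None, None, None)  # passes: LoRA args not required
```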

finetrainers/cogvideox/__init__.py

Lines changed: 1 addition & 0 deletions

````diff
@@ -1 +1,2 @@
 from .cogvideox_lora import COGVIDEOX_T2V_LORA_CONFIG
+from .full_finetune import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG
````

finetrainers/cogvideox/cogvideox_lora.py

Lines changed: 1 addition & 0 deletions

````diff
@@ -311,6 +311,7 @@ def _pad_frames(latents: torch.Tensor, patch_size_t: int):
     return latents
 
 
+# TODO(aryan): refactor into model specs for better re-use
 COGVIDEOX_T2V_LORA_CONFIG = {
     "pipeline_cls": CogVideoXPipeline,
     "load_condition_models": load_condition_models,
````
finetrainers/cogvideox/full_finetune.py

Lines changed: 32 additions & 0 deletions

````diff
@@ -0,0 +1,32 @@
+from diffusers import CogVideoXPipeline
+
+from .cogvideox_lora import (
+    calculate_noisy_latents,
+    collate_fn_t2v,
+    forward_pass,
+    initialize_pipeline,
+    load_condition_models,
+    load_diffusion_models,
+    load_latent_models,
+    post_latent_preparation,
+    prepare_conditions,
+    prepare_latents,
+    validation,
+)
+
+
+# TODO(aryan): refactor into model specs for better re-use
+COGVIDEOX_T2V_FULL_FINETUNE_CONFIG = {
+    "pipeline_cls": CogVideoXPipeline,
+    "load_condition_models": load_condition_models,
+    "load_latent_models": load_latent_models,
+    "load_diffusion_models": load_diffusion_models,
+    "initialize_pipeline": initialize_pipeline,
+    "prepare_conditions": prepare_conditions,
+    "prepare_latents": prepare_latents,
+    "post_latent_preparation": post_latent_preparation,
+    "collate_fn": collate_fn_t2v,
+    "calculate_noisy_latents": calculate_noisy_latents,
+    "forward_pass": forward_pass,
+    "validation": validation,
+}
````
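These per-model dicts act as lightweight model specs: the full-finetune config reuses every callable from the LoRA module, so a trainer can select a spec by training type and call its entries without model-specific branching. A hypothetical dispatch sketch (the trainer-side wiring is not part of this diff):

```python
# Hypothetical selection of a model spec by training type (illustrative only).
from finetrainers.cogvideox import (
    COGVIDEOX_T2V_FULL_FINETUNE_CONFIG,
    COGVIDEOX_T2V_LORA_CONFIG,
)

CONFIGS = {
    "lora": COGVIDEOX_T2V_LORA_CONFIG,
    "full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG,
}

spec = CONFIGS["full-finetune"]
pipeline_cls = spec["pipeline_cls"]   # CogVideoXPipeline
validation_fn = spec["validation"]    # shared with the LoRA config
```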
finetrainers/hunyuan_video/__init__.py

Lines changed: 1 addition & 0 deletions

````diff
@@ -1 +1,2 @@
+from .full_finetune import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG
 from .hunyuan_video_lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG
````
