
Commit d220aac

(fake*) FP8 training support (#184)
* remove mixed_precision
* update
* make style
* update
* better defaults for experimenting
* fix train continuation after validation error
* update READMEs
* remove granularity
* update hook implementation to latest diffusers
* update
* update
* remove unused patches
* remove mixed precision in tests
* add changes lost in merge conflict resolution
* update README date
1 parent f5f9cc0 commit d220aac

19 files changed: +513 −99 lines

README.md

Lines changed: 6 additions & 6 deletions
```diff
@@ -12,6 +12,7 @@ FineTrainers is a work-in-progress library to support (accessible) training of v
 
 ## News
 
+- 🔥 **2025-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB up to specific resolutions.
 - 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to @ArEnSc for taking up the initiative!
 - 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
 - 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
@@ -83,7 +84,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal"
 # Training arguments
 training_cmd="--training_type lora \
   --seed 42 \
-  --mixed_precision bf16 \
   --batch_size 1 \
   --train_steps 3000 \
   --rank 128 \
@@ -140,14 +140,14 @@ For inference, refer [here](./docs/training/ltx_video.md#inference). For docs re
 
 | **Model Name** | **Tasks** | **Min. LoRA VRAM<sup>*</sup>** | **Min. Full Finetuning VRAM<sup>^</sup>** |
 |:------------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:|
-| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 11 GB | 21 GB |
-| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 42 GB | OOM |
-| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 21 GB | 53 GB |
+| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 5 GB | 21 GB |
+| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 32 GB | OOM |
+| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 18 GB | 53 GB |
 
 </div>
 
-<sub><sup>*</sup>Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using fp8 weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).</sub><br/>
-<sub><sup>^</sup>Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using bf16 weights & gradient checkpointing.</sub>
+<sub><sup>*</sup>Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using **FP8** weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).</sub><br/>
+<sub><sup>^</sup>Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using **BF16** weights & gradient checkpointing.</sub>
 
 If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md).
 
```

accelerate_configs/compiled_1.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@ enable_cpu_affinity: false
 gpu_ids: '3'
 machine_rank: 0
 main_training_function: main
-mixed_precision: fp16
+mixed_precision: bf16
 num_machines: 1
 num_processes: 1
 rdzv_backend: static
```

accelerate_configs/uncompiled_1.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,7 +6,7 @@ enable_cpu_affinity: false
 gpu_ids: '3'
 machine_rank: 0
 main_training_function: main
-mixed_precision: fp16
+mixed_precision: bf16
 num_machines: 1
 num_processes: 1
 rdzv_backend: static
```

docs/training/cogvideox.md

Lines changed: 6 additions & 1 deletion
````diff
@@ -37,7 +37,6 @@ dataloader_cmd="--dataloader_num_workers 4"
 # Training arguments
 training_cmd="--training_type lora \
   --seed 42 \
-  --mixed_precision bf16 \
   --batch_size 1 \
   --precompute_conditions \
   --train_steps 1000 \
@@ -88,6 +87,12 @@ echo -ne "-------------------- Finished executing script --------------------\n\
 
 ### LoRA
 
+<!-- TODO(aryan): Update these numbers for 49x512x768 -->
+
+> [!NOTE]
+>
+> The measurements below are done in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
+
 LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x480x720` resolutions, **with precomputation**:
 
 ```
````

docs/training/hunyuan_video.md

Lines changed: 4 additions & 1 deletion
````diff
@@ -42,7 +42,6 @@ diffusion_cmd=""
 # Training arguments
 training_cmd="--training_type lora \
   --seed 42 \
-  --mixed_precision bf16 \
   --batch_size 1 \
   --train_steps 500 \
   --rank 128 \
@@ -91,6 +90,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\
 
 ### LoRA
 
+> [!NOTE]
+>
+> The measurements below are done in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
+
 LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**:
 
 ```
````

docs/training/ltx_video.md

Lines changed: 4 additions & 1 deletion
````diff
@@ -41,7 +41,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal"
 # Training arguments
 training_cmd="--training_type lora \
   --seed 42 \
-  --mixed_precision bf16 \
   --batch_size 1 \
   --train_steps 3000 \
   --rank 128 \
@@ -90,6 +89,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\
 
 ### LoRA
 
+> [!NOTE]
+>
+> The measurements below are done in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
+
 LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:
 
 ```
````

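To make the "halves the memory requirement for model weights" claim in the note above concrete, here is a small standalone sketch (not part of this commit) comparing per-parameter storage in `bf16` and `fp8`. The parameter count is a made-up round number, and the snippet assumes a PyTorch build that ships the `float8` dtypes.

```python
import torch

num_params = 5_000_000_000  # hypothetical ~5B-parameter transformer, for illustration only

for dtype in (torch.bfloat16, torch.float8_e4m3fn):
    bytes_per_param = torch.empty(0, dtype=dtype).element_size()
    gib = num_params * bytes_per_param / 1024**3
    print(f"{str(dtype):22s} -> ~{gib:.1f} GiB of weight storage")

# torch.bfloat16         -> ~9.3 GiB of weight storage
# torch.float8_e4m3fn    -> ~4.7 GiB of weight storage
```

Only the weight storage shrinks; activations, gradients, and optimizer states are governed by the compute dtype and the optimizer, which is why the docs describe the saving as applying to model weights.
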
docs/training/optimization.md

Lines changed: 7 additions & 4 deletions
```diff
@@ -1,9 +1,12 @@
+# Memory optimizations
+
 To lower memory requirements during training:
 
+- `--precompute_conditions`: this precomputes the conditions and latents, and loads them as required during training, which saves a significant amount of time and memory.
+- `--gradient_checkpointing`: this saves memory by recomputing activations during the backward pass.
+- `--layerwise_upcasting_modules transformer`: naively casts the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`. This halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`)
+- `--use_8bit_bnb`: this is only applicable to Adam and AdamW optimizers, and makes use of 8-bit precision to store optimizer states.
 - Use a DeepSpeed config to launch training (refer to [`accelerate_configs/deepspeed.yaml`](./accelerate_configs/deepspeed.yaml) as an example).
-- Pass `--precompute_conditions` when launching training.
-- Pass `--gradient_checkpointing` when launching training.
-- Pass `--use_8bit_bnb` when launching training. Note that this is only applicable to Adam and AdamW optimizers.
 - Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
 
-We will continue to add more features that help to reduce memory consumption.
+We will continue to add more features that help to reduce memory consumption.
```

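As a rough guide to what the `--use_8bit_bnb` bullet above saves, the sketch below (not from this commit) estimates AdamW optimizer-state memory with 32-bit versus 8-bit storage. AdamW keeps two moment buffers per trainable parameter; the trainable-parameter count is a hypothetical placeholder, and the small per-block quantization constants kept by 8-bit optimizers are ignored.

```python
# Back-of-the-envelope AdamW optimizer-state memory; illustrative numbers only.
trainable_params = 150_000_000  # hypothetical trainable-parameter count (e.g. a large LoRA)

# AdamW stores two moments (exp_avg, exp_avg_sq) per trainable parameter.
fp32_states_mib = trainable_params * 2 * 4 / 1024**2  # 4 bytes per 32-bit state
int8_states_mib = trainable_params * 2 * 1 / 1024**2  # 1 byte per 8-bit state

print(f"fp32 optimizer states:  ~{fp32_states_mib:.0f} MiB")
print(f"8-bit optimizer states: ~{int8_states_mib:.0f} MiB")
```
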
finetrainers/args.py

Lines changed: 58 additions & 18 deletions
```diff
@@ -43,6 +43,14 @@ class Args:
         Data type for the transformer model.
     vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
         Data type for the VAE model.
+    layerwise_upcasting_modules (`List[str]`, defaults to `[]`):
+        Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer'].
+    layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`):
+        Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2'].
+    layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`):
+        Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when cast to fp8 precision
+        naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
+        by default, and recommend adding more layers to the default list based on the model architecture.
 
     DATASET ARGUMENTS
     -----------------
@@ -126,8 +134,6 @@ class Args:
         Type of training to perform. Choose between ['lora'].
     seed (`int`, defaults to `42`):
         A seed for reproducible training.
-    mixed_precision (`str`, defaults to `None`):
-        Whether to use mixed precision. Choose between ['no', 'fp8', 'fp16', 'bf16'].
     batch_size (`int`, defaults to `1`):
         Per-device batch size.
     train_epochs (`int`, defaults to `1`):
@@ -243,6 +249,18 @@ class Args:
     text_encoder_3_dtype: torch.dtype = torch.bfloat16
     transformer_dtype: torch.dtype = torch.bfloat16
     vae_dtype: torch.dtype = torch.bfloat16
+    layerwise_upcasting_modules: List[str] = []
+    layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
+    layerwise_upcasting_skip_modules_pattern: List[str] = [
+        "patch_embed",
+        "pos_embed",
+        "x_embedder",
+        "context_embedder",
+        "time_embed",
+        "^proj_in$",
+        "^proj_out$",
+        "norm",
+    ]
 
     # Dataset arguments
     data_root: str = None
@@ -277,9 +295,6 @@ class Args:
     # Training arguments
     training_type: str = None
     seed: int = 42
-    mixed_precision: str = (
-        None  # TODO: consider removing later https://github.com/a-r-r-o-w/finetrainers/pull/139#discussion_r1897438414
-    )
     batch_size: int = 1
     train_epochs: int = 1
     train_steps: int = None
@@ -347,6 +362,9 @@ def to_dict(self) -> Dict[str, Any]:
                 "text_encoder_3_dtype": self.text_encoder_3_dtype,
                 "transformer_dtype": self.transformer_dtype,
                 "vae_dtype": self.vae_dtype,
+                "layerwise_upcasting_modules": self.layerwise_upcasting_modules,
+                "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
+                "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
             },
             "dataset_arguments": {
                 "data_root": self.data_root,
@@ -381,7 +399,6 @@ def to_dict(self) -> Dict[str, Any]:
             "training_arguments": {
                 "training_type": self.training_type,
                 "seed": self.seed,
-                "mixed_precision": self.mixed_precision,
                 "batch_size": self.batch_size,
                 "train_epochs": self.train_epochs,
                 "train_steps": self.train_steps,
@@ -464,6 +481,7 @@ def parse_arguments() -> Args:
 
 
 def validate_args(args: Args):
+    _validated_model_args(args)
     _validate_training_args(args)
     _validate_validation_args(args)
 
@@ -506,6 +524,28 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
     parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
     parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
     parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")
+    parser.add_argument(
+        "--layerwise_upcasting_modules",
+        type=str,
+        default=[],
+        nargs="+",
+        choices=["transformer"],
+        help="Modules that should have fp8 storage weights but higher precision computation.",
+    )
+    parser.add_argument(
+        "--layerwise_upcasting_storage_dtype",
+        type=str,
+        default="float8_e4m3fn",
+        choices=["float8_e4m3fn", "float8_e5m2"],
+        help="Data type for the layerwise upcasting storage.",
+    )
+    parser.add_argument(
+        "--layerwise_upcasting_skip_modules_pattern",
+        type=str,
+        default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
+        nargs="+",
+        help="Modules to skip for layerwise upcasting.",
+    )
 
 
 def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
@@ -688,16 +728,6 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
         help="Type of training to perform. Choose between ['lora', 'full-finetune']",
     )
     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="no",
-        choices=["no", "fp8", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Defaults to the value of accelerate config of the current system or the "
-            "flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
-        ),
-    )
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -979,8 +1009,9 @@ def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
     "bf16": torch.bfloat16,
     "fp16": torch.float16,
     "fp32": torch.float32,
+    "float8_e4m3fn": torch.float8_e4m3fn,
+    "float8_e5m2": torch.float8_e5m2,
 }
-_INVERSE_DTYPE_MAP = {v: k for k, v in _DTYPE_MAP.items()}
 
 
 def _map_to_args_type(args: Dict[str, Any]) -> Args:
@@ -997,6 +1028,9 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
     result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
     result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]
+    result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules
+    result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype]
+    result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern
 
     # Dataset arguments
     if args.data_root is None and args.dataset_file is None:
@@ -1034,7 +1068,6 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     # Training arguments
     result_args.training_type = args.training_type
    result_args.seed = args.seed
-    result_args.mixed_precision = args.mixed_precision
     result_args.batch_size = args.batch_size
     result_args.train_epochs = args.train_epochs
     result_args.train_steps = args.train_steps
@@ -1117,6 +1150,13 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     return result_args
 
 
+def _validated_model_args(args: Args):
+    if args.training_type == "full-finetune":
+        assert (
+            "transformer" not in args.layerwise_upcasting_modules
+        ), "Layerwise upcasting is not supported for full-finetune training"
+
+
 def _validate_training_args(args: Args):
     if args.training_type == "lora":
         assert args.rank is not None, "Rank is required for LoRA training"
```

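The skip patterns introduced above are plain regular expressions, which is why some entries are anchored (`^proj_in$`) while others are bare substrings (`norm`). The standalone sketch below only illustrates those regex semantics; the module names, and the assumption that patterns are tested against fully qualified dotted names, are illustrative rather than taken from the repository's matching code.

```python
import re

skip_patterns = [
    "patch_embed", "pos_embed", "x_embedder", "context_embedder",
    "^proj_in$", "^proj_out$", "norm",
]

# Hypothetical module names, assuming fully qualified dotted names are tested.
module_names = [
    "proj_in",                         # matched by the anchored ^proj_in$ as an exact name
    "transformer_blocks.0.norm1",      # "norm" is unanchored, so any name containing it matches
    "transformer_blocks.0.attn.to_q",  # no pattern matches -> eligible for fp8 storage
]

for name in module_names:
    skipped = any(re.search(pattern, name) for pattern in skip_patterns)
    print(f"{name:33s} skipped_from_fp8_storage={skipped}")
```
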
finetrainers/cogvideox/lora.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -65,6 +65,7 @@ def initialize_pipeline(
     enable_slicing: bool = False,
     enable_tiling: bool = False,
     enable_model_cpu_offload: bool = False,
+    is_training: bool = False,
     **kwargs,
 ) -> CogVideoXPipeline:
     component_name_pairs = [
@@ -81,9 +82,14 @@
 
     pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
     pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
-    pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
     pipe.vae = pipe.vae.to(dtype=vae_dtype)
 
+    # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
+    # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
+    # DDP optimizer step.
+    if not is_training:
+        pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
+
     if enable_slicing:
         pipe.vae.enable_slicing()
     if enable_tiling:
```

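The `is_training` guard above exists because, during training, the transformer may already carry fp8 layerwise-upcasting hooks, and re-casting it would interfere with them (the comment in the hunk mentions an error at the DDP optimizer step). Purely as an illustration of the underlying condition, and not what the repository does, the same state could be detected from the parameter storage dtypes:

```python
import torch
import torch.nn as nn

def has_fp8_storage(module: nn.Module) -> bool:
    """Return True if any parameter is stored in one of the fp8 dtypes."""
    fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
    return any(p.dtype in fp8_dtypes for p in module.parameters())

# Stand-in for pipe.transformer; the repository keys the decision off the
# explicit `is_training` flag instead of inspecting dtypes like this.
transformer = nn.Linear(8, 8).to(torch.bfloat16)
if not has_fp8_storage(transformer):
    transformer = transformer.to(dtype=torch.bfloat16)
```
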
finetrainers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+from .layerwise_upcasting import apply_layerwise_upcasting
```

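This hunk only shows the public import of the new hooks module, not its implementation. As orientation for what a name like `apply_layerwise_upcasting` implies, here is a minimal standalone sketch of the pattern: weights kept in a low-precision storage dtype and temporarily upcast to the compute dtype around each submodule's forward pass. It is an assumption-laden illustration, not the repository's code, and it omits details a real implementation needs (non-blocking transfers, buffer handling, hook removal, and the full default skip-pattern list).

```python
import re
import torch
import torch.nn as nn

def layerwise_upcasting_sketch(
    module: nn.Module,
    storage_dtype: torch.dtype = torch.float8_e4m3fn,
    compute_dtype: torch.dtype = torch.bfloat16,
    skip_modules_pattern=("norm", "^proj_in$", "^proj_out$"),
) -> None:
    for name, submodule in module.named_modules():
        # Only touch leaf modules that own a weight, and honor the skip patterns
        # (normalization-like layers degrade badly under naive fp8 storage).
        if getattr(submodule, "weight", None) is None:
            continue
        if any(re.search(pattern, name) for pattern in skip_modules_pattern):
            continue

        submodule.to(storage_dtype)  # keep the (frozen) weights in fp8

        def upcast(mod, args):
            mod.to(compute_dtype)    # run the forward pass in the compute dtype
            return args

        def downcast(mod, args, output):
            mod.to(storage_dtype)    # return to fp8 storage right after the forward
            return output

        submodule.register_forward_pre_hook(upcast)
        submodule.register_forward_hook(downcast)

# Example: a small MLP whose weights live in fp8 but whose math runs in bf16.
mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
layerwise_upcasting_sketch(mlp)
out = mlp(torch.randn(2, 64, dtype=torch.bfloat16))
print(out.dtype, mlp[0].weight.dtype)  # torch.bfloat16 torch.float8_e4m3fn
```

In the trainer, this behavior is enabled with `--layerwise_upcasting_modules transformer`, with the storage dtype and skip patterns controlled by the new arguments added in `finetrainers/args.py` above.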