
Commit a080f0d

[Training Utils] create a utility for casting the lora params during training. (#6553)
create a utility for casting the lora params during training.
1 parent 79df503 commit a080f0d
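
In short, the four LoRA example scripts previously inlined the same upcasting loop; this commit moves that loop into a shared helper in diffusers.training_utils. Roughly, paraphrasing the per-script diffs below:

# before (repeated in each script)
for param in unet.parameters():
    # only upcast trainable parameters (LoRA) into fp32
    if param.requires_grad:
        param.data = param.to(torch.float32)

# after
from diffusers.training_utils import cast_training_params
cast_training_params(unet, dtype=torch.float32)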

5 files changed: +26 -28 lines changed

examples/consistency_distillation/train_lcm_distill_lora_sdxl.py

Lines changed: 3 additions & 5 deletions
@@ -51,7 +51,7 @@
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import resolve_interpolation_mode
+from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
@@ -860,10 +860,8 @@ def main(args):
 
     # Make sure the trainable params are in float32.
     if args.mixed_precision == "fp16":
-        for param in unet.parameters():
-            # only upcast trainable parameters (LoRA) into fp32
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
 
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 6 additions & 11 deletions
@@ -53,7 +53,7 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,
@@ -1086,11 +1086,8 @@ def load_model_hook(models, input_dir):
             models = [unet_]
             if args.train_text_encoder:
                 models.extend([text_encoder_one_, text_encoder_two_])
-            for model in models:
-                for param in model.parameters():
-                    # only upcast trainable parameters (LoRA) into fp32
-                    if param.requires_grad:
-                        param.data = param.to(torch.float32)
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1110,11 +1107,9 @@ def load_model_hook(models, input_dir):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
 
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 3 additions & 5 deletions
@@ -43,7 +43,7 @@
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
@@ -466,10 +466,8 @@ def main():
     # Add adapter and make sure the trainable params are in float32.
     unet.add_adapter(unet_lora_config)
     if args.mixed_precision == "fp16":
-        for param in unet.parameters():
-            # only upcast trainable parameters (LoRA) into fp32
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 3 additions & 6 deletions
@@ -51,7 +51,7 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
@@ -634,11 +634,8 @@ def main(args):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):

src/diffusers/training_utils.py

Lines changed: 11 additions & 1 deletion
@@ -1,7 +1,7 @@
 import contextlib
 import copy
 import random
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import numpy as np
 import torch
@@ -121,6 +121,16 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     return lora_state_dict
 
 
+def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    if not isinstance(model, list):
+        model = [model]
+    for m in model:
+        for param in m.parameters():
+            # only upcast trainable parameters into fp32
+            if param.requires_grad:
+                param.data = param.to(dtype)
+
+
 def _set_state_dict_into_text_encoder(
     lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
 ):
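
For reference, a minimal sketch of how the new helper behaves. The toy torch.nn.Linear module below is only illustrative and not part of this commit: frozen parameters keep their original dtype, while any parameter with requires_grad=True is upcast to the requested dtype.

import torch

from diffusers.training_utils import cast_training_params

# Toy stand-in for a LoRA-augmented model loaded in fp16: the frozen base
# weight should stay in fp16, while the trainable parameter gets upcast.
model = torch.nn.Linear(4, 4).to(torch.float16)
model.weight.requires_grad_(False)  # frozen base weight
model.bias.requires_grad_(True)     # stands in for a trainable LoRA parameter

# Accepts a single module or a list of modules, matching the call sites above.
cast_training_params(model, dtype=torch.float32)

print(model.weight.dtype)  # torch.float16 -- unchanged, not trainable
print(model.bias.dtype)    # torch.float32 -- upcast, trainable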
