diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index afe30680567d..5b78501f9b49 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -39,7 +39,7 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from packaging import version -from peft import LoraConfig +from peft import LoraConfig, set_peft_model_state_dict from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose @@ -59,12 +59,13 @@ ) from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr +from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr from diffusers.utils import ( check_min_version, convert_all_state_dict_to_peft, convert_state_dict_to_diffusers, convert_state_dict_to_kohya, + convert_unet_state_dict_to_peft, is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card @@ -1319,6 +1320,37 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") + lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir) + + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_ + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [unet_] + if args.train_text_encoder: + models.extend([text_encoder_one_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir) StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)