@@ -34,7 +34,7 @@
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -53,8 +53,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
@@ -997,17 +1002,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -1064,17 +1058,39 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
 
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
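
For reference, the new restore path in load_model_hook can be exercised on its own. The sketch below is not part of the diff: it reads a diffusers-format LoRA checkpoint, converts its UNet entries to PEFT's key layout with convert_unet_state_dict_to_peft, and sets them on a LoRA-enabled UNet via set_peft_model_state_dict, mirroring the hunk above. The base model id, the checkpoint folder name, and the rank are placeholder assumptions.

# Standalone sketch (not part of the commit): restore UNet LoRA weights from a saved
# checkpoint the same way load_model_hook does above.
from diffusers import UNet2DConditionModel
from diffusers.loaders import LoraLoaderMixin
from diffusers.utils import convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"  # assumed base model
)
# The adapter config must match the one used at training time; rank 4 is an assumption here.
unet.add_adapter(
    LoraConfig(r=4, lora_alpha=4, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"])
)

lora_state_dict, _ = LoraLoaderMixin.lora_state_dict("checkpoint-500")  # placeholder path
unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet, unet_state_dict, adapter_name="default")
if getattr(incompatible_keys, "unexpected_keys", None):
    print(f"unexpected keys: {incompatible_keys.unexpected_keys}")
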
@@ -1089,6 +1105,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 
     if args.train_text_encoder:
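
The float32 upcast now runs here, after the hooks are registered, and again inside load_model_hook when resuming, so the LoRA parameters always train in full precision while the frozen base weights stay in weight_dtype. Below is a minimal sketch of the same pattern factored into a standalone helper; the helper name is illustrative and not part of the script. Recent diffusers releases expose a comparable utility, cast_training_params, in diffusers.training_utils.

# Sketch, not part of the commit: the fp32-upcast pattern above as a reusable helper.
import torch

def upcast_trainable_params(models, dtype=torch.float32):
    # Frozen base weights keep their mixed-precision dtype; only parameters that
    # receive gradients (the LoRA layers here) are promoted to full precision.
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)
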
@@ -1506,6 +1533,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()