 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -708,6 +713,56 @@ def collate_fn(examples):
         num_workers=args.dataloader_num_workers,
     )
 
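+    # Custom hooks so that `accelerator.save_state(...)` / `load_state(...)` serialize and restore only the LoRA weights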
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            unet_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(unwrap_model(unet))):
+                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                else:
+                    raise ValueError(f"Unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            StableDiffusionPipeline.save_lora_weights(
+                save_directory=output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                safe_serialization=True,
+            )
+
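+    # Load hook: restores the LoRA weights into the UNet when resuming from a checkpoint via `accelerator.load_state(...)`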
+    def load_model_hook(models, input_dir):
+        unet_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+            if isinstance(model, type(unwrap_model(unet))):
+                unet_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        # returns a tuple of state dictionary and network alphas
+        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
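+        # keep only the UNet entries, drop their `unet.` prefix, and convert the keys to the PEFT naming scheme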
+        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            # throw a warning if some unexpected keys are found and continue loading
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        # Make sure the trainable params are in float32
+        if args.mixed_precision in ["fp16"]:
+            cast_training_params([unet_], dtype=torch.float32)
+
     # Scheduler and math around the number of training steps.
     # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
@@ -732,6 +787,10 @@ def collate_fn(examples):
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
+    # Register the hooks for efficient saving and loading of LoRA weights
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
@@ -906,17 +965,6 @@ def collate_fn(examples):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
 
-                        unwrapped_unet = unwrap_model(unet)
-                        unet_lora_state_dict = convert_state_dict_to_diffusers(
-                            get_peft_model_state_dict(unwrapped_unet)
-                        )
-
-                        StableDiffusionPipeline.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=unet_lora_state_dict,
-                            safe_serialization=True,
-                        )
-
                         logger.info(f"Saved state to {save_path}")
 
             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}