@@ -644,24 +644,24 @@ def load_lora_weights(
644644 # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
645645 state_dict , network_alphas = self .lora_state_dict (
646646 pretrained_model_name_or_path_or_dict ,
647- unet_config = self .unet .config ,
647+ unet_config = self .unet .config if hasattr ( self , "unet" ) else None ,
648648 ** kwargs ,
649649 )
650650
651651 is_correct_format = all ("lora" in key for key in state_dict .keys ())
652652 if not is_correct_format :
653653 raise ValueError ("Invalid LoRA checkpoint." )
654-
655- self .load_lora_into_unet (
656- state_dict ,
657- network_alphas = network_alphas ,
658- unet = self .unet ,
659- adapter_name = adapter_name ,
660- _pipeline = self ,
661- low_cpu_mem_usage = low_cpu_mem_usage ,
662- )
654+ if hasattr ( self , "unet" ):
655+ self .load_lora_into_unet (
656+ state_dict ,
657+ network_alphas = network_alphas ,
658+ unet = self .unet ,
659+ adapter_name = adapter_name ,
660+ _pipeline = self ,
661+ low_cpu_mem_usage = low_cpu_mem_usage ,
662+ )
663663 text_encoder_state_dict = {k : v for k , v in state_dict .items () if "text_encoder." in k }
664- if len (text_encoder_state_dict ) > 0 :
664+ if len (text_encoder_state_dict ) > 0 and hasattr ( self , "text_encoder" ) :
665665 self .load_lora_into_text_encoder (
666666 text_encoder_state_dict ,
667667 network_alphas = network_alphas ,
@@ -674,7 +674,7 @@ def load_lora_weights(
674674 )
675675
676676 text_encoder_2_state_dict = {k : v for k , v in state_dict .items () if "text_encoder_2." in k }
677- if len (text_encoder_2_state_dict ) > 0 :
677+ if len (text_encoder_2_state_dict ) > 0 and hasattr ( self , "text_encoder_2" ) :
678678 self .load_lora_into_text_encoder (
679679 text_encoder_2_state_dict ,
680680 network_alphas = network_alphas ,
0 commit comments