@@ -201,6 +201,32 @@ def get_dummy_tokens(self):
201201 prepared_inputs ["input_ids" ] = inputs
202202 return prepared_inputs
203203
204+ def _get_lora_state_dicts (self , modules_to_save ):
205+ state_dicts = {}
206+ for module_name , module in modules_to_save .items ():
207+ if module is not None :
208+ state_dicts [f"{ module_name } _lora_layers" ] = get_peft_model_state_dict (module )
209+ return state_dicts
210+
211+ def _get_modules_to_save (self , pipe , has_denoiser = False ):
212+ modules_to_save = {}
213+ lora_loadable_modules = self .pipeline_class ._lora_loadable_modules
214+
215+ if "text_encoder" in lora_loadable_modules and hasattr (pipe , "text_encoder" ):
216+ modules_to_save ["text_encoder" ] = pipe .text_encoder
217+
218+ if "text_encoder_2" in lora_loadable_modules and hasattr (pipe , "text_encoder_2" ):
219+ modules_to_save ["text_encoder_2" ] = pipe .text_encoder_2
220+
221+ if has_denoiser :
222+ if "unet" in lora_loadable_modules and hasattr (pipe , "unet" ):
223+ modules_to_save ["unet" ] = pipe .unet
224+
225+ if "transformer" in lora_loadable_modules and hasattr (pipe , "transformer" ):
226+ modules_to_save ["transformer" ] = pipe .transformer
227+
228+ return modules_to_save
229+
204230 def test_simple_inference (self ):
205231 """
206232 Tests a simple inference and makes sure it works as expected
@@ -420,45 +446,21 @@ def test_simple_inference_with_text_lora_save_load(self):
420446 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
421447
422448 with tempfile .TemporaryDirectory () as tmpdirname :
423- text_encoder_state_dict = get_peft_model_state_dict (pipe .text_encoder )
424- if self .has_two_text_encoders or self .has_three_text_encoders :
425- if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
426- text_encoder_2_state_dict = get_peft_model_state_dict (pipe .text_encoder_2 )
427-
428- self .pipeline_class .save_lora_weights (
429- save_directory = tmpdirname ,
430- text_encoder_lora_layers = text_encoder_state_dict ,
431- text_encoder_2_lora_layers = text_encoder_2_state_dict ,
432- safe_serialization = False ,
433- )
434- else :
435- self .pipeline_class .save_lora_weights (
436- save_directory = tmpdirname ,
437- text_encoder_lora_layers = text_encoder_state_dict ,
438- safe_serialization = False ,
439- )
449+ modules_to_save = self ._get_modules_to_save (pipe )
450+ lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
440451
441- if self .has_two_text_encoders :
442- if "text_encoder_2" not in self .pipeline_class ._lora_loadable_modules :
443- self .pipeline_class .save_lora_weights (
444- save_directory = tmpdirname ,
445- text_encoder_lora_layers = text_encoder_state_dict ,
446- safe_serialization = False ,
447- )
452+ self .pipeline_class .save_lora_weights (
453+ save_directory = tmpdirname , safe_serialization = False , ** lora_state_dicts
454+ )
448455
449456 self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.bin" )))
450457 pipe .unload_lora_weights ()
451-
452458 pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.bin" ))
453459
454- images_lora_from_pretrained = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
455- self .assertTrue (check_if_lora_correctly_set (pipe . text_encoder ), "Lora not correctly set in text encoder " )
460+ for module_name , module in modules_to_save . items ():
461+ self .assertTrue (check_if_lora_correctly_set (module ), f "Lora not correctly set in { module_name } " )
456462
457- if self .has_two_text_encoders or self .has_three_text_encoders :
458- if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
459- self .assertTrue (
460- check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2"
461- )
463+ images_lora_from_pretrained = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
462464
463465 self .assertTrue (
464466 np .allclose (images_lora , images_lora_from_pretrained , atol = 1e-3 , rtol = 1e-3 ),
@@ -614,54 +616,20 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
614616 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
615617
616618 with tempfile .TemporaryDirectory () as tmpdirname :
617- text_encoder_state_dict = (
618- get_peft_model_state_dict ( pipe . text_encoder )
619- if "text_encoder" in self .pipeline_class ._lora_loadable_modules
620- else None
619+ modules_to_save = self . _get_modules_to_save ( pipe , has_denoiser = True )
620+ lora_state_dicts = self . _get_lora_state_dicts ( modules_to_save )
621+ self .pipeline_class .save_lora_weights (
622+ save_directory = tmpdirname , safe_serialization = False , ** lora_state_dicts
621623 )
622624
623- denoiser_state_dict = get_peft_model_state_dict (denoiser )
624-
625- saving_kwargs = {
626- "save_directory" : tmpdirname ,
627- "safe_serialization" : False ,
628- }
629-
630- if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
631- saving_kwargs .update ({"text_encoder_lora_layers" : text_encoder_state_dict })
632-
633- if self .unet_kwargs is not None :
634- saving_kwargs .update ({"unet_lora_layers" : denoiser_state_dict })
635- else :
636- saving_kwargs .update ({"transformer_lora_layers" : denoiser_state_dict })
637-
638- if self .has_two_text_encoders or self .has_three_text_encoders :
639- if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
640- text_encoder_2_state_dict = get_peft_model_state_dict (pipe .text_encoder_2 )
641- saving_kwargs .update ({"text_encoder_2_lora_layers" : text_encoder_2_state_dict })
642-
643- self .pipeline_class .save_lora_weights (** saving_kwargs )
644-
645625 self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.bin" )))
646626 pipe .unload_lora_weights ()
647-
648627 pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.bin" ))
649628
650- images_lora_from_pretrained = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
651-
652- if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
653- self .assertTrue (
654- check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder"
655- )
656-
657- self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser" )
658-
659- if self .has_two_text_encoders or self .has_three_text_encoders :
660- if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
661- self .assertTrue (
662- check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2"
663- )
629+ for module_name , module in modules_to_save .items ():
630+ self .assertTrue (check_if_lora_correctly_set (module ), f"Lora not correctly set in { module_name } " )
664631
632+ images_lora_from_pretrained = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
665633 self .assertTrue (
666634 np .allclose (images_lora , images_lora_from_pretrained , atol = 1e-3 , rtol = 1e-3 ),
667635 "Loading from saved checkpoints should give same results." ,
0 commit comments