@@ -658,12 +658,12 @@ def test_simple_inference_with_partial_text_lora(self):
658658                "Removing adapters should change the output" ,
659659            )
660660
661-     def  test_simple_inference_save_pretrained (self ):
661+     def  test_simple_inference_save_pretrained_with_text_lora (self ):
662662        """ 
663663        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained 
664664        """ 
665665        for  scheduler_cls  in  self .scheduler_classes :
666-             components , text_lora_config , denoiser_lora_config  =  self .get_dummy_components (scheduler_cls )
666+             components , text_lora_config , _  =  self .get_dummy_components (scheduler_cls )
667667            pipe  =  self .pipeline_class (** components )
668668            pipe  =  pipe .to (torch_device )
669669            pipe .set_progress_bar_config (disable = None )
@@ -672,7 +672,7 @@ def test_simple_inference_save_pretrained(self):
672672            output_no_lora  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
673673            self .assertTrue (output_no_lora .shape  ==  self .output_shape )
674674
675-             pipe , _  =  self .check_if_adapters_added_correctly (pipe , text_lora_config , denoiser_lora_config )
675+             pipe , _  =  self .check_if_adapters_added_correctly (pipe , text_lora_config , denoiser_lora_config = None )
676676            images_lora  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
677677
678678            with  tempfile .TemporaryDirectory () as  tmpdirname :
@@ -681,14 +681,11 @@ def test_simple_inference_save_pretrained(self):
681681                pipe_from_pretrained  =  self .pipeline_class .from_pretrained (tmpdirname )
682682                pipe_from_pretrained .to (torch_device )
683683
684-             self .assertTrue (
685-                 check_if_lora_correctly_set (pipe_from_pretrained .text_encoder ),
686-                 "Lora not correctly set in text encoder" ,
687-             )
688- 
689-             denoiser  =  pipe .transformer  if  self .unet_kwargs  is  None  else  pipe .unet 
690-             denoiser .add_adapter (denoiser_lora_config )
691-             self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
684+             if  "text_encoder"  in  self .pipeline_class ._lora_loadable_modules :
685+                 self .assertTrue (
686+                     check_if_lora_correctly_set (pipe_from_pretrained .text_encoder ),
687+                     "Lora not correctly set in text encoder" ,
688+                 )
692689
693690            if  self .has_two_text_encoders  or  self .has_three_text_encoders :
694691                if  "text_encoder_2"  in  self .pipeline_class ._lora_loadable_modules :
@@ -988,46 +985,53 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
988985            )
989986
990987    def  test_wrong_adapter_name_raises_error (self ):
988+         adapter_name  =  "adapter-1" 
989+ 
991990        scheduler_cls  =  self .scheduler_classes [0 ]
992991        components , text_lora_config , denoiser_lora_config  =  self .get_dummy_components (scheduler_cls )
993992        pipe  =  self .pipeline_class (** components )
994993        pipe  =  pipe .to (torch_device )
995994        pipe .set_progress_bar_config (disable = None )
996995        _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
997996
998-         pipe , _  =  self .check_if_adapters_added_correctly (pipe , text_lora_config , denoiser_lora_config )
997+         pipe , _  =  self .check_if_adapters_added_correctly (
998+             pipe , text_lora_config , denoiser_lora_config , adapter_name = adapter_name 
999+         )
9991000
10001001        with  self .assertRaises (ValueError ) as  err_context :
10011002            pipe .set_adapters ("test" )
10021003
10031004        self .assertTrue ("not in the list of present adapters"  in  str (err_context .exception ))
10041005
10051006        # test this works. 
1006-         pipe .set_adapters ("adapter-1" )
1007+         pipe .set_adapters (adapter_name )
10071008        _  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
10081009
10091010    def  test_multiple_wrong_adapter_name_raises_error (self ):
1011+         adapter_name  =  "adapter-1" 
10101012        scheduler_cls  =  self .scheduler_classes [0 ]
10111013        components , text_lora_config , denoiser_lora_config  =  self .get_dummy_components (scheduler_cls )
10121014        pipe  =  self .pipeline_class (** components )
10131015        pipe  =  pipe .to (torch_device )
10141016        pipe .set_progress_bar_config (disable = None )
10151017        _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
10161018
1017-         pipe , _  =  self .check_if_adapters_added_correctly (pipe , text_lora_config , denoiser_lora_config )
1019+         pipe , _  =  self .check_if_adapters_added_correctly (
1020+             pipe , text_lora_config , denoiser_lora_config , adapter_name = adapter_name 
1021+         )
10181022
10191023        scale_with_wrong_components  =  {"foo" : 0.0 , "bar" : 0.0 , "tik" : 0.0 }
10201024        logger  =  logging .get_logger ("diffusers.loaders.lora_base" )
10211025        logger .setLevel (30 )
10221026        with  CaptureLogger (logger ) as  cap_logger :
1023-             pipe .set_adapters ("adapter-1" , adapter_weights = scale_with_wrong_components )
1027+             pipe .set_adapters (adapter_name , adapter_weights = scale_with_wrong_components )
10241028
10251029        wrong_components  =  sorted (set (scale_with_wrong_components .keys ()))
10261030        msg  =  f"The following components in `adapter_weights` are not part of the pipeline: { wrong_components }  
10271031        self .assertTrue (msg  in  str (cap_logger .out ))
10281032
10291033        # test this works. 
1030-         pipe .set_adapters ("adapter-1" )
1034+         pipe .set_adapters (adapter_name )
10311035        _  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
10321036
10331037    def  test_simple_inference_with_text_denoiser_block_scale (self ):
0 commit comments