@@ -1001,9 +1001,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
10011001 )
10021002
10031003 pipe .fuse_lora (components = self .pipeline_class ._lora_loadable_modules )
1004+ assert pipe .num_fused_loras == 1
10041005 output_fused_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
10051006
10061007 pipe .unfuse_lora (components = self .pipeline_class ._lora_loadable_modules )
1008+ assert pipe .num_fused_loras == 0
10071009 output_unfused_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
10081010
10091011 # unloading should remove the LoRA layers
@@ -1740,26 +1742,21 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
17401742 self .assertTrue (
17411743 check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder"
17421744 )
1745+ pipe .text_encoder .add_adapter (text_lora_config , "adapter-2" )
17431746
17441747 denoiser = pipe .transformer if self .unet_kwargs is None else pipe .unet
17451748 denoiser .add_adapter (denoiser_lora_config , "adapter-1" )
1746-
1747- # Attach a second adapter
1748- if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
1749- pipe .text_encoder .add_adapter (text_lora_config , "adapter-2" )
1750-
1751- denoiser .add_adapter (denoiser_lora_config , "adapter-2" )
1752-
17531749 self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
1750+ denoiser .add_adapter (denoiser_lora_config , "adapter-2" )
17541751
17551752 if self .has_two_text_encoders or self .has_three_text_encoders :
17561753 lora_loadable_components = self .pipeline_class ._lora_loadable_modules
17571754 if "text_encoder_2" in lora_loadable_components :
17581755 pipe .text_encoder_2 .add_adapter (text_lora_config , "adapter-1" )
1759- pipe .text_encoder_2 .add_adapter (text_lora_config , "adapter-2" )
17601756 self .assertTrue (
17611757 check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2"
17621758 )
1759+ pipe .text_encoder_2 .add_adapter (text_lora_config , "adapter-2" )
17631760
17641761 # set them to multi-adapter inference mode
17651762 pipe .set_adapters (["adapter-1" , "adapter-2" ])
@@ -1769,6 +1766,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
17691766 outputs_lora_1 = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
17701767
17711768 pipe .fuse_lora (components = self .pipeline_class ._lora_loadable_modules , adapter_names = ["adapter-1" ])
1769+ assert pipe .num_fused_loras == 1
17721770
17731771 # Fusing should still keep the LoRA layers so outpout should remain the same
17741772 outputs_lora_1_fused = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
@@ -1779,9 +1777,24 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
17791777 )
17801778
17811779 pipe .unfuse_lora (components = self .pipeline_class ._lora_loadable_modules )
1780+ assert pipe .num_fused_loras == 0
1781+
1782+ # unloading should remove the LoRA layers
1783+ if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
1784+ self .assertTrue (check_if_lora_correctly_set (pipe .text_encoder ), "Unfuse should still keep LoRA layers" )
1785+
1786+ self .assertTrue (check_if_lora_correctly_set (denoiser ), "Unfuse should still keep LoRA layers" )
1787+
1788+ if self .has_two_text_encoders or self .has_three_text_encoders :
1789+ if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
1790+ self .assertTrue (
1791+ check_if_lora_correctly_set (pipe .text_encoder_2 ), "Unfuse should still keep LoRA layers"
1792+ )
1793+
17821794 pipe .fuse_lora (
17831795 components = self .pipeline_class ._lora_loadable_modules , adapter_names = ["adapter-2" , "adapter-1" ]
17841796 )
1797+ assert pipe .num_fused_loras == 2
17851798
17861799 # Fusing should still keep the LoRA layers
17871800 output_all_lora_fused = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
0 commit comments