
Commit faa6ddd: improve lora fusion tests
1 parent: 31c4f24

File tree: 2 files changed (+26, -10)

src/diffusers/loaders/lora_base.py

Lines changed: 5 additions & 2 deletions
@@ -607,7 +607,10 @@ def fuse_lora(
                     model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
                 )
 
-        self.num_fused_loras += 1
+        if adapter_names is None:
+            self.num_fused_loras += 1
+        elif isinstance(adapter_names, list):
+            self.num_fused_loras += len(adapter_names)
 
     def unfuse_lora(self, components: List[str] = [], **kwargs):
         r"""
@@ -663,7 +666,7 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
                     if isinstance(module, BaseTunerLayer):
                         module.unmerge()
 
-        self.num_fused_loras -= 1
+        self.num_fused_loras = 0
 
     def set_adapters(
         self,
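
The counting change above is worth spelling out: fuse_lora used to bump num_fused_loras by one per call and unfuse_lora decremented it by one, which drifts out of sync once a single call fuses several named adapters. After this commit, fusing counts one per named adapter (or one when no names are given), and unfusing resets the counter to zero, since unmerge() detaches every fused adapter at once. A minimal standalone sketch of that bookkeeping, using a hypothetical FusionCounter class rather than the real LoraBaseMixin:

# Hypothetical stand-in for the bookkeeping in LoraBaseMixin (lora_base.py);
# only the counter logic from the diff above is reproduced here.
from typing import List, Optional


class FusionCounter:
    def __init__(self) -> None:
        self.num_fused_loras = 0

    def fuse_lora(self, adapter_names: Optional[List[str]] = None) -> None:
        if adapter_names is None:
            # No explicit names: the active adapters merge as a single step.
            self.num_fused_loras += 1
        elif isinstance(adapter_names, list):
            # Explicit names: count each adapter that gets merged.
            self.num_fused_loras += len(adapter_names)

    def unfuse_lora(self) -> None:
        # In the real code, unmerge() detaches every fused adapter,
        # so the counter resets instead of decrementing by one.
        self.num_fused_loras = 0


counter = FusionCounter()
counter.fuse_lora(adapter_names=["adapter-1", "adapter-2"])
assert counter.num_fused_loras == 2
counter.unfuse_lora()
assert counter.num_fused_loras == 0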

tests/lora/utils.py

Lines changed: 21 additions & 8 deletions
@@ -1001,9 +1001,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
         )
 
         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        assert pipe.num_fused_loras == 1
         output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        assert pipe.num_fused_loras == 0
         output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         # unloading should remove the LoRA layers
@@ -1740,26 +1742,21 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             self.assertTrue(
                 check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
             )
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
 
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-
-        # Attach a second adapter
-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-
-        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
-
         self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
 
         if self.has_two_text_encoders or self.has_three_text_encoders:
             lora_loadable_components = self.pipeline_class._lora_loadable_modules
             if "text_encoder_2" in lora_loadable_components:
                 pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                 )
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
 
         # set them to multi-adapter inference mode
         pipe.set_adapters(["adapter-1", "adapter-2"])
@@ -1769,6 +1766,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
         outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+        assert pipe.num_fused_loras == 1
 
         # Fusing should still keep the LoRA layers so output should remain the same
         outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1779,9 +1777,24 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
         )
 
         pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        assert pipe.num_fused_loras == 0
+
+        # unfusing should still keep the LoRA layers attached
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
+
+        if self.has_two_text_encoders or self.has_three_text_encoders:
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+                )
+
         pipe.fuse_lora(
             components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
         )
+        assert pipe.num_fused_loras == 2
 
         # Fusing should still keep the LoRA layers
         output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
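
From user code, the behavior these tests pin down looks roughly like the sketch below. It assumes a concrete pipeline class (which supplies a default components list to fuse_lora and unfuse_lora) with two LoRA adapters already loaded; the repo ids are placeholders, not taken from the test suite.

import torch
from diffusers import DiffusionPipeline

# Placeholder repo ids; substitute a real base model and LoRA checkpoints.
pipe = DiffusionPipeline.from_pretrained("base/model", torch_dtype=torch.float16)
pipe.load_lora_weights("loras/first", adapter_name="adapter-1")
pipe.load_lora_weights("loras/second", adapter_name="adapter-2")
pipe.set_adapters(["adapter-1", "adapter-2"])

# Fusing one named adapter now counts exactly that adapter.
pipe.fuse_lora(adapter_names=["adapter-1"])
assert pipe.num_fused_loras == 1

# Unfusing unmerges everything fused so far; the counter resets to zero
# while the LoRA layers themselves stay attached to the models.
pipe.unfuse_lora()
assert pipe.num_fused_loras == 0

# Fusing both adapters counts both of them.
pipe.fuse_lora(adapter_names=["adapter-2", "adapter-1"])
assert pipe.num_fused_loras == 2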
