Commit e32cb31

Commit message: updates
Parent: 2e0bf0a

3 files changed: 18 additions, 8 deletions

src/diffusers/loaders/lora_base.py

Lines changed: 7 additions & 1 deletion
@@ -665,13 +665,19 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
             if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
                 for module in model.modules():
                     if isinstance(module, BaseTunerLayer):
+                        for adapter in set(module.merged_adapters):
+                            if adapter and adapter in self._merged_adapters:
+                                self._merged_adapters.remove(adapter)
                         module.unmerge()
-                        self._merged_adapters.update(set(module.merged_adapters))

     @property
     def num_fused_loras(self):
         return len(self._merged_adapters)

+    @property
+    def fused_loras(self):
+        return self._merged_adapters
+
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
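
Net effect: unfuse_lora now removes each adapter from self._merged_adapters before calling module.unmerge(). The removal has to come first because unmerge() clears module.merged_adapters, which is also why the old post-unmerge update was effectively a no-op and could leave the set stale. The new fused_loras property exposes the live set. Below is a runnable toy sketch of that bookkeeping; ToyTunerLayer and ToyLoraMixin are hypothetical stand-ins, not the real diffusers/PEFT classes:

    class ToyTunerLayer:
        def __init__(self, adapters):
            self.merged_adapters = list(adapters)

        def unmerge(self):
            # mirrors PEFT: after unmerging, nothing is recorded as merged
            self.merged_adapters = []


    class ToyLoraMixin:
        def __init__(self):
            self._merged_adapters = set()

        def fuse(self, module):
            self._merged_adapters.update(set(module.merged_adapters))

        def unfuse(self, module):
            # drop adapters from the tracking set *before* unmerge() erases
            # module.merged_adapters, matching the ordering in the diff above
            for adapter in set(module.merged_adapters):
                if adapter and adapter in self._merged_adapters:
                    self._merged_adapters.remove(adapter)
            module.unmerge()

        @property
        def num_fused_loras(self):
            return len(self._merged_adapters)

        @property
        def fused_loras(self):
            return self._merged_adapters


    pipe = ToyLoraMixin()
    layer = ToyTunerLayer(["adapter-1", "adapter-2"])
    pipe.fuse(layer)
    assert pipe.num_fused_loras == 2, pipe.fused_loras
    pipe.unfuse(layer)
    assert pipe.num_fused_loras == 0, pipe.fused_loras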

tests/lora/test_lora_layers_cogvideox.py

Lines changed: 3 additions & 0 deletions
@@ -124,6 +124,9 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
     def test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

+    def test_lora_scale_kwargs_match_fusion(self):
+        super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
+
     @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
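
This mirrors the override just above it: the pipeline-specific class reuses the shared test from tests/lora/utils.py and only widens the tolerances (9e-3 versus the 1e-3 defaults), presumably to absorb CogVideoX's larger numeric drift. The pattern in isolation, with hypothetical class names:

    import unittest

    class SharedLoRATests:  # stands in for the shared mixin in tests/lora/utils.py
        def test_lora_scale_kwargs_match_fusion(self, expected_atol=1e-3, expected_rtol=1e-3):
            ...  # shared assertions comparing lora_scale kwargs vs. fused outputs

    class CogVideoXStyleTests(SharedLoRATests, unittest.TestCase):  # hypothetical name
        def test_lora_scale_kwargs_match_fusion(self):
            # widen tolerances for a numerically noisier pipeline
            super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)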

tests/lora/utils.py

Lines changed: 8 additions & 7 deletions
@@ -877,11 +877,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
         pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
-        assert pipe.num_fused_loras == 1
+        assert pipe.num_fused_loras == 1, pipe.num_fused_loras
         output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
-        assert pipe.num_fused_loras == 0
+        assert pipe.num_fused_loras == 0, pipe.num_fused_loras
         output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         # unloading should remove the LoRA layers
@@ -1623,7 +1623,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
         outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
-        assert pipe.num_fused_loras == 1
+        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

         # Fusing should still keep the LoRA layers so output should remain the same
         outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1634,9 +1634,8 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
         )

         pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
-        assert pipe.num_fused_loras == 0
+        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

-        # unloading should remove the LoRA layers
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
@@ -1651,14 +1650,16 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
         pipe.fuse_lora(
             components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
         )
-        assert pipe.num_fused_loras == 2
+        self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

         # Fusing should still keep the LoRA layers
         output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
             "Fused lora should not change the output",
         )
+        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

     def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
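
The upgraded assertion messages use Python 3.8+ self-documenting f-strings: f"{expr=}" expands to the expression text plus its value, so a failure reports something like pipe.num_fused_loras=1, pipe.fused_loras={'adapter-1'} instead of a bare AssertionError. A one-liner showing the form:

    n = 2
    print(f"{n=}")  # prints: n=2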
@@ -1702,7 +1703,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
             adapter_names=["adapter-1"],
             lora_scale=lora_scale,
         )
-        assert pipe.num_fused_loras == 1
+        assert pipe.num_fused_loras == 1, pipe.num_fused_loras

         outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
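Taken together, the commit makes the pipeline-level fuse/unfuse counters trustworthy. A sketch of the user-facing flow this touches; the model and LoRA repo ids are placeholders, not part of the commit:

    import torch
    from diffusers import DiffusionPipeline

    # "some/model-id" and "some/lora-repo" are placeholders
    pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.float16)
    pipe.load_lora_weights("some/lora-repo", adapter_name="adapter-1")

    pipe.fuse_lora(components=["unet"], adapter_names=["adapter-1"])
    print(pipe.num_fused_loras)  # 1
    print(pipe.fused_loras)      # the new property: the set of fused adapter names

    pipe.unfuse_lora(components=["unet"])
    print(pipe.num_fused_loras)  # 0, no longer stale after unfuse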