
Commit 5f8cde6

more improvements.
1 parent faa6ddd commit 5f8cde6

File tree: 2 files changed, +82 -25 lines

src/diffusers/loaders/lora_base.py

Lines changed: 8 additions & 4 deletions
@@ -592,6 +592,7 @@ def fuse_lora(
         if len(components) == 0:
             raise ValueError("`components` cannot be an empty list.")

+        merged_adapters = set()
         for fuse_component in components:
             if fuse_component not in self._lora_loadable_modules:
                 raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -601,16 +602,19 @@ def fuse_lora(
                 # check if diffusers model
                 if issubclass(model.__class__, ModelMixin):
                     model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
+                    for module in model.modules():
+                        if isinstance(module, BaseTunerLayer):
+                            merged_adapters.update(set(module.merged_adapters))
                 # handle transformers models.
                 if issubclass(model.__class__, PreTrainedModel):
                     fuse_text_encoder_lora(
                         model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
                     )
+                    for module in model.modules():
+                        if isinstance(module, BaseTunerLayer):
+                            merged_adapters.update(set(module.merged_adapters))

-        if adapter_names is None:
-            self.num_fused_loras += 1
-        elif isinstance(adapter_names, list):
-            self.num_fused_loras += len(adapter_names)
+        self.num_fused_loras += len(merged_adapters)

     def unfuse_lora(self, components: List[str] = [], **kwargs):
         r"""

tests/lora/utils.py

Lines changed: 74 additions & 21 deletions
@@ -22,6 +22,7 @@
 import numpy as np
 import pytest
 import torch
+from parameterized import parameterized

 from diffusers import (
     AutoencoderKL,
@@ -80,6 +81,18 @@ def initialize_dummy_state_dict(state_dict):
 POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


+def determine_attention_kwargs_name(pipeline_class):
+    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
+
+    # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
+    for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
+        if possible_attention_kwargs in call_signature_keys:
+            attention_kwargs_name = possible_attention_kwargs
+            break
+    assert attention_kwargs_name is not None
+    return attention_kwargs_name
+
+
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     pipeline_class = None
@@ -440,14 +453,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         Tests a simple inference with lora attached on the text encoder + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-
-        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -803,12 +809,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         Tests a simple inference with lora attached on the text encoder + Unet + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1765,7 +1766,8 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             pipe.set_adapters(["adapter-1"])
             outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

-            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            # pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
             assert pipe.num_fused_loras == 1

             # Fusing should still keep the LoRA layers so outpout should remain the same
@@ -1803,6 +1805,62 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
                 "Fused lora should not change the output",
             )

+    @parameterized.expand([1.0, 0.8])
+    def test_lora_scale_kwargs_match_fusion(
+        self, lora_scale, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
+    ):
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(output_no_lora.shape == self.output_shape)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+
+            pipe.set_adapters(["adapter-1"])
+            attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
+            pipe.fuse_lora(
+                components=self.pipeline_class._lora_loadable_modules,
+                adapter_names=["adapter-1"],
+                lora_scale=lora_scale,
+            )
+            assert pipe.num_fused_loras == 1
+
+            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            self.assertTrue(
+                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
+                "Fused lora should not change the output",
+            )
+            self.assertFalse(
+                np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
+                "LoRA should change the output",
+            )
+
     @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
         for scheduler_cls in self.scheduler_classes:
@@ -2007,12 +2065,7 @@ def test_logs_info_when_no_lora_keys_found(self):

     def test_set_adapters_match_attention_kwargs(self):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
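
The new `test_lora_scale_kwargs_match_fusion` test parameterizes over LoRA scales and asserts that scaling at call time (through the pipeline's attention kwargs) matches baking the same scale into the weights with `fuse_lora(..., lora_scale=...)`. A rough user-level sketch of the equivalence being exercised follows; the model and LoRA identifiers are placeholders, not part of this commit:

# Sketch of the behavior the new test checks, at the user-API level.
# "some/base-model" and "some/lora-checkpoint" are placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/base-model", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("some/lora-checkpoint", adapter_name="adapter-1")

prompt = "a photo of an astronaut riding a horse"
scale = 0.8

# 1) Scale applied at call time through the pipeline's attention kwargs.
#    The exact kwarg name varies per pipeline: cross_attention_kwargs,
#    joint_attention_kwargs, or attention_kwargs.
image_runtime = pipe(
    prompt,
    generator=torch.manual_seed(0),
    cross_attention_kwargs={"scale": scale},
).images[0]

# 2) The same scale baked into the weights via fusion.
pipe.fuse_lora(adapter_names=["adapter-1"], lora_scale=scale)
image_fused = pipe(prompt, generator=torch.manual_seed(0)).images[0]

# The test asserts the two outputs agree within atol/rtol of 1e-3,
# and that both differ from the output with no LoRA loaded.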
