
Commit 3aba99a

younesbelkada, sayakpaul, and BenjaminBossan authored
[Peft / Lora] Add adapter_names in fuse_lora (#5823)
* add adapter_name in fuse
* add test
* up
* fix CI
* adapt from suggestion
* Update src/diffusers/utils/testing_utils.py (Co-authored-by: Benjamin Bossan <[email protected]>)
* change to `require_peft_version_greater`
* change variable names in test
* Update src/diffusers/loaders/lora.py (Co-authored-by: Benjamin Bossan <[email protected]>)
* break into 2 lines
* final comments

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
1 parent 6683f97 commit 3aba99a

File tree

- docs/source/en/tutorials/using_peft_for_inference.md
- src/diffusers/loaders/lora.py
- src/diffusers/loaders/unet.py
- src/diffusers/utils/testing_utils.py
- tests/lora/test_lora_layers_peft.py

5 files changed: +173 −11 lines changed

docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 23 additions & 0 deletions
@@ -183,3 +183,26 @@ image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
 # Gets the Unet back to the original state
 pipe.unfuse_lora()
 ```
+
+You can also fuse some adapters using `adapter_names` for faster generation:
+
+```py
+pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+
+pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
+# Fuses the LoRAs into the Unet
+pipe.fuse_lora(adapter_names=["pixel"])
+
+prompt = "a hacker with a hoodie, pixel art"
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+
+# Gets the Unet back to the original state
+pipe.unfuse_lora()
+
+# Fuse all adapters
+pipe.fuse_lora(adapter_names=["pixel", "toy"])
+
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+```
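For contrast with the snippet above, here is a minimal sketch of the default behavior described by the new `fuse_lora` docstring: when no `adapter_names` is passed, all active adapters are fused. The model and adapter repositories are the same ones used in the docs example; like the examples above, this needs the weights downloaded and a CUDA device.

```py
import torch
from diffusers import DiffusionPipeline

# Sketch of the default path: fuse_lora() without adapter_names fuses all
# active adapters, per the docstring added in this commit.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])

pipe.fuse_lora()  # no adapter_names: fuses both "pixel" and "toy"
pipe.unfuse_lora()  # restores the original Unet weights
```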

src/diffusers/loaders/lora.py

Lines changed: 43 additions & 7 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union
@@ -1001,6 +1002,7 @@ def fuse_lora(
         fuse_text_encoder: bool = True,
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
     ):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
@@ -1020,6 +1022,21 @@
                 Controls how much to influence the outputs with the LoRA parameters.
             safe_fusing (`bool`, defaults to `False`):
                 Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
         """
         if fuse_unet or fuse_text_encoder:
             self.num_fused_loras += 1
@@ -1030,24 +1047,43 @@
 
         if fuse_unet:
             unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
 
         if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
-                # TODO(Patrick, Younes): enable "safe" fusing
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+                merge_kwargs = {"safe_merge": safe_fusing}
+
                 for module in text_encoder.modules():
                     if isinstance(module, BaseTunerLayer):
                         if lora_scale != 1.0:
                             module.scale_layer(lora_scale)
 
-                        module.merge()
+                        # For BC with previous PEFT versions, we need to check the signature
+                        # of the `merge` method to see if it supports the `adapter_names` argument.
+                        supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                        if "adapter_names" in supported_merge_kwargs:
+                            merge_kwargs["adapter_names"] = adapter_names
+                        elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                            raise ValueError(
+                                "The `adapter_names` argument is not supported with your PEFT version. "
+                                "Please upgrade to the latest version of PEFT. `pip install -U peft`"
+                            )
+
+                        module.merge(**merge_kwargs)
 
         else:
             deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
+                if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
+                    raise ValueError(
+                        "The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
+                        "backend to use this argument by installing latest PEFT and transformers."
+                        " `pip install -U peft transformers`"
+                    )
+
                 for _, attn_module in text_encoder_attn_modules(text_encoder):
                     if isinstance(attn_module.q_proj, PatchedLoraProjection):
                         attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
@@ -1062,9 +1098,9 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
+                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names)
 
     def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
         r"""

src/diffusers/loaders/unet.py

Lines changed: 27 additions & 4 deletions
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 from collections import defaultdict
 from contextlib import nullcontext
+from functools import partial
 from typing import Callable, Dict, List, Optional, Union
 
 import safetensors
@@ -504,22 +506,43 @@ def save_function(weights, filename):
             save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
         self.lora_scale = lora_scale
         self._safe_fusing = safe_fusing
-        self.apply(self._fuse_lora_apply)
+        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
 
-    def _fuse_lora_apply(self, module):
+    def _fuse_lora_apply(self, module, adapter_names=None):
         if not USE_PEFT_BACKEND:
             if hasattr(module, "_fuse_lora"):
                 module._fuse_lora(self.lora_scale, self._safe_fusing)
+
+            if adapter_names is not None:
+                raise ValueError(
+                    "The `adapter_names` argument is not supported in your environment. Please switch"
+                    " to PEFT backend to use this argument by installing latest PEFT and transformers."
+                    " `pip install -U peft transformers`"
+                )
         else:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
+            merge_kwargs = {"safe_merge": self._safe_fusing}
+
             if isinstance(module, BaseTunerLayer):
                 if self.lora_scale != 1.0:
                     module.scale_layer(self.lora_scale)
-                module.merge(safe_merge=self._safe_fusing)
+
+                # For BC with previous PEFT versions, we need to check the signature
+                # of the `merge` method to see if it supports the `adapter_names` argument.
+                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                if "adapter_names" in supported_merge_kwargs:
+                    merge_kwargs["adapter_names"] = adapter_names
+                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                    raise ValueError(
+                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                        " to the latest version of PEFT. `pip install -U peft`"
+                    )
+
+                module.merge(**merge_kwargs)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
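The `functools.partial` change above is needed because `nn.Module.apply` calls its visitor with each submodule as the only argument, so `partial` pre-binds the extra `adapter_names` keyword. A minimal sketch of the pattern, with a small `nn.Sequential` as a hypothetical stand-in for the UNet:

```py
from functools import partial

import torch.nn as nn

def visit(module, adapter_names=None):
    # nn.Module.apply passes only `module`; `adapter_names` arrives via partial.
    print(type(module).__name__, adapter_names)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())  # stand-in for the UNet
model.apply(partial(visit, adapter_names=["pixel"]))  # visits Linear, ReLU, Sequential
```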

src/diffusers/utils/testing_utils.py

Lines changed: 17 additions & 0 deletions
@@ -300,6 +300,23 @@ def require_peft_backend(test_case):
     return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
 
 
+def require_peft_version_greater(peft_version):
+    """
+    Decorator marking a test that requires a PEFT version greater than the given one. Such tests may also require
+    specific versions of PEFT and transformers.
+    """
+
+    def decorator(test_case):
+        correct_peft_version = is_peft_available() and version.parse(
+            version.parse(importlib.metadata.version("peft")).base_version
+        ) > version.parse(peft_version)
+        return unittest.skipUnless(
+            correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
+        )(test_case)
+
+    return decorator
+
+
 def deprecate_after_peft_backend(test_case):
     """
     Decorator marking a test that will be skipped after PEFT backend
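For reference, the version check above can be exercised on its own; `base_version` strips pre-release and dev suffixes (e.g. "0.7.0.dev0" becomes "0.7.0") before the comparison, so dev builds of a newer release still pass. A small sketch, assuming `peft` may or may not be installed:

```py
import importlib.metadata

from packaging import version

def peft_version_greater(min_version: str) -> bool:
    # Mirrors the check in require_peft_version_greater: compare the installed
    # base version against the threshold, strictly greater-than.
    try:
        installed = importlib.metadata.version("peft")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(version.parse(installed).base_version) > version.parse(min_version)

print(peft_version_greater("0.6.2"))
```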

tests/lora/test_lora_layers_peft.py

Lines changed: 63 additions & 0 deletions
@@ -50,6 +50,7 @@
     nightly,
     numpy_cosine_similarity_distance,
     require_peft_backend,
+    require_peft_version_greater,
     require_torch_gpu,
     slow,
     torch_device,
@@ -1105,6 +1106,68 @@ def test_get_list_adapters(self):
             {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
         )
 
+    @require_peft_version_greater(peft_version="0.6.2")
+    def test_simple_inference_with_text_lora_unet_fused_multi(self):
+        """
+        Tests a simple inference with LoRA attached to the text encoder and the unet, fuses the LoRA weights
+        into the base model, and makes sure it works as expected in the multi-adapter case.
+        """
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(self.torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+
+            # Attach a second adapter
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
+            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                self.assertTrue(
+                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+            # set them to multi-adapter inference mode
+            pipe.set_adapters(["adapter-1", "adapter-2"])
+            outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.set_adapters(["adapter-1"])
+            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.fuse_lora(adapter_names=["adapter-1"])
+
+            # Fusing should still keep the LoRA layers so the output should remain the same
+            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=1e-3, rtol=1e-3),
+                "Fused lora should not change the output",
+            )
+
+            pipe.unfuse_lora()
+            pipe.fuse_lora(adapter_names=["adapter-2", "adapter-1"])
+
+            # Fusing should still keep the LoRA layers
+            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(
+                np.allclose(output_all_lora_fused, outputs_all_lora, atol=1e-3, rtol=1e-3),
+                "Fused lora should not change the output",
+            )
+
     @unittest.skip("This is failing for now - need to investigate")
     def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
         """
