Commit d890aa2
ENH Improve speed of expanding LoRA scales
Resolves #11816.

The following call proved to be a bottleneck when setting many LoRA adapters in diffusers:

https://github.com/huggingface/diffusers/blob/cdaf84a708eadf17d731657f4be3fa39d09a12c0/src/diffusers/loaders/peft.py#L482

This is because unet.state_dict() was called repeatedly, even though in the standard case it is not needed:

https://github.com/huggingface/diffusers/blob/cdaf84a708eadf17d731657f4be3fa39d09a12c0/src/diffusers/loaders/unet_loader_utils.py#L55

This PR fixes the issue by deferring that call so it only runs when it is actually necessary, not earlier.
1 parent 05e7a85 commit d890aa2
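The change follows a simple lazy-evaluation pattern: pass the model itself down to the helper and only materialize its state dict on the path that actually needs it. A minimal sketch of the idea, using an illustrative helper name (validate_scale_layers) rather than the actual diffusers function:

from torch import nn

def validate_scale_layers(scales: dict, model: nn.Module) -> None:
    # Hypothetical helper sketching the commit's pattern; not the real diffusers API.
    if not scales:
        # Common case: nothing left to check, so the (expensive) state dict is never built.
        return
    # Materialize the state dict only now that a layer name must be validated.
    state_dict_keys = model.state_dict().keys()
    for layer in scales:
        if not any(layer in key for key in state_dict_keys):
            raise ValueError(f"Can't set lora scale for layer {layer}.")

Passing the model instead of model.state_dict() keeps the call site cheap: the previous signature forced every caller to pay for full state-dict construction per adapter, even when no layer names had to be validated.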

File tree

1 file changed (+5 / -3 lines)

src/diffusers/loaders/unet_loader_utils.py

Lines changed: 5 additions & 3 deletions
@@ -14,6 +14,8 @@
 import copy
 from typing import TYPE_CHECKING, Dict, List, Union
 
+from torch import nn
+
 from ..utils import logging
 
 
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
             weight_for_adapter,
             blocks_with_transformer,
             transformer_per_block,
-            unet.state_dict(),
+            model=unet,
             default_scale=default_scale,
         )
         for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
     scales: Union[float, Dict],
     blocks_with_transformer: Dict[str, int],
     transformer_per_block: Dict[str, int],
-    state_dict: None,
+    model: nn.Module,
     default_scale: float = 1.0,
 ):
     """
@@ -155,7 +157,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
         del scales[updown]
 
     for layer in scales.keys():
-        if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
+        if not any(_translate_into_actual_layer_name(layer) in module for module in model.state_dict().keys()):
             raise ValueError(
                 f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
             )

0 commit comments
