
Commit 908d151

fix

1 parent b6ca9d9 · commit 908d151
3 files changed: +17 -2 lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 6 additions & 1 deletion

@@ -1981,7 +1981,8 @@ def _load_norm_into_transformer(
         logger.info(
             "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
             'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
-            "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
+            "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
+            "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
         )
 
         # We can't load with strict=True because the current state_dict does not contain all the transformer keys
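
For context, a minimal sketch of the partial-loading behavior this comment refers to (toy module and key names, not the diffusers implementation): loading with `strict=False` applies the keys that are present and reports the rest instead of raising.

import torch
from torch import nn

# Hypothetical stand-in for a transformer; only its norm keys appear in the partial state dict.
class TinyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(8)
        self.proj = nn.Linear(8, 8)

model = TinyTransformer()

# A state dict carrying only normalization weights, as extracted from a LoRA checkpoint.
norm_state_dict = {"norm.weight": torch.ones(8), "norm.bias": torch.zeros(8)}

# strict=True would raise because "proj.weight"/"proj.bias" are missing;
# strict=False loads what is present and reports the remainder.
result = model.load_state_dict(norm_state_dict, strict=False)
print(result.missing_keys)     # ['proj.weight', 'proj.bias']
print(result.unexpected_keys)  # []
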
@@ -2286,6 +2287,10 @@ def _maybe_expand_transformer_param_shape_or_error_(
         norm_state_dict=None,
         prefix=None,
     ) -> bool:
+        """
+        Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
+        generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
+        """
         state_dict = {}
         if lora_state_dict is not None:
             state_dict.update(lora_state_dict)
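
As a rough illustration of the kind of expansion the new docstring describes (a sketch with an invented helper, not the actual `_maybe_expand_transformer_param_shape_or_error_` logic): growing a linear layer's input dimension, e.g. from 64 to 128, by zero-padding the new input columns so the existing weights are preserved.

import torch
from torch import nn

def expand_linear_in_features(linear: nn.Linear, new_in_features: int) -> nn.Linear:
    # Illustrative only: copy the old weight into a wider layer and zero the new columns.
    old_out, old_in = linear.weight.shape
    assert new_in_features >= old_in
    expanded = nn.Linear(new_in_features, old_out, bias=linear.bias is not None)
    with torch.no_grad():
        expanded.weight.zero_()
        expanded.weight[:, :old_in] = linear.weight
        if linear.bias is not None:
            expanded.bias.copy_(linear.bias)
    return expanded

# The (3072, 64) -> (3072, 128) case mentioned in the docstring:
layer = nn.Linear(64, 3072)
print(expand_linear_in_features(layer, 128).weight.shape)  # torch.Size([3072, 128])
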

src/diffusers/loaders/peft.py

Lines changed: 10 additions & 0 deletions

@@ -57,6 +57,12 @@
 
 
 def _maybe_adjust_config(config):
+    """
+    We may run into some ambiguous configuration values when a model has module names that share a common prefix
+    (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) but have different LoRA ranks. This
+    method removes the ambiguity by following what is described here:
+    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
+    """
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
     original_r = config["r"]
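
A hedged sketch of the ambiguity the docstring refers to, using invented module names, ranks, and a simplified suffix matcher (this is not PEFT's actual matching logic): a short `rank_pattern` key such as "proj_out" ends up selecting one rank for every module that ends with that name, even though the nested module was trained with a different rank.

def naive_rank_for(module_name, rank_pattern, default_r):
    # Simplified stand-in for pattern matching; real PEFT matching is more involved.
    for pattern, rank in rank_pattern.items():
        if module_name == pattern or module_name.endswith("." + pattern):
            return rank
    return default_r

rank_pattern = {"proj_out": 8}
default_r = 64

# Both modules resolve to rank 8, although the nested one was trained with rank 64:
print(naive_rank_for("proj_out", rank_pattern, default_r))                     # 8
print(naive_rank_for("blocks.transformer.proj_out", rank_pattern, default_r))  # 8
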
@@ -65,6 +71,10 @@ def _maybe_adjust_config(config):
         key_rank = rank_pattern[key]
 
         # try to detect ambiguity
+        # `target_modules` can also be a str, in which case this loop would loop
+        # over the chars of the str. The technically correct way to match LoRA keys
+        # in PEFT is to use LoraModel._check_target_module_exists(lora_config, key).
+        # But this cuts it for now.
         exact_matches = [mod for mod in target_modules if mod == key]
         substring_matches = [mod for mod in target_modules if key in mod and mod != key]
         ambiguous_key = key
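
The caveat in the added comment is easy to reproduce: if `target_modules` happens to be a plain string rather than a list, iterating over it yields single characters, so the exact-match and substring checks silently come back empty. A small sketch with toy values:

key = "proj_out"

# list case: behaves as intended
target_modules = ["proj_out", "blocks.transformer.proj_out"]
print([mod for mod in target_modules if mod == key])                 # ['proj_out']
print([mod for mod in target_modules if key in mod and mod != key])  # ['blocks.transformer.proj_out']

# str case: the loop runs over characters, so nothing matches
target_modules = "proj_out"
print([mod for mod in target_modules if mod == key])                 # []
print([mod for mod in target_modules if key in mod and mod != key])  # []
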

src/diffusers/utils/peft_utils.py

Lines changed: 1 addition & 1 deletion

@@ -181,7 +181,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
     use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
     # for now we know that the "bias" keys are only associated with `lora_B`.
-    lora_bias = any("lora_B" and "bias" in k for k in peft_state_dict)
+    lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)
 
     lora_config_kwargs = {
         "r": r,

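This one-line change fixes a truthiness bug in the old expression: because `in` binds tighter than `and`, `"lora_B" and "bias" in k` evaluates to just `"bias" in k` (the non-empty string "lora_B" is always truthy and is never tested against `k`). A quick sketch with made-up keys:

peft_state_dict = {
    "foo.lora_A.weight": None,
    "foo.lora_B.weight": None,
    "foo.bias": None,  # a plain module bias, not a LoRA bias (values are irrelevant here)
}

# Old expression: equivalent to any("bias" in k ...), so it wrongly returns True.
print(any("lora_B" and "bias" in k for k in peft_state_dict))               # True

# Fixed expression: requires a lora_B key that actually ends with ".bias".
print(any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict))  # False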