Commit cbc4432

finish 2.
1 parent 76f9d82 commit cbc4432

File tree

1 file changed: +11 -14 lines changed


src/diffusers/loaders/peft.py

Lines changed: 11 additions & 14 deletions
@@ -54,7 +54,6 @@
     "SanaTransformer2DModel": lambda model_cls, weights: weights,
     "Lumina2Transformer2DModel": lambda model_cls, weights: weights,
 }
-_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]


 def _maybe_adjust_config(config):
@@ -64,40 +63,38 @@ def _maybe_adjust_config(config):
     method removes the ambiguity by following what is described here:
     https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
     """
+    # Track keys that have been explicitly removed to prevent re-adding them.
+    deleted_keys = set()
+
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
     original_r = config["r"]

     for key in list(rank_pattern.keys()):
-        if any(prefix in key for prefix in _NO_CONFIG_UPDATE_KEYS):
-            continue
         key_rank = rank_pattern[key]

         # try to detect ambiguity
-        # `target_modules` can also be a str, in which case this loop would loop
-        # over the chars of the str. The technically correct way to match LoRA keys
-        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
-        # But this cuts it for now.
         exact_matches = [mod for mod in target_modules if mod == key]
         substring_matches = [mod for mod in target_modules if key in mod and mod != key]
         ambiguous_key = key

         if exact_matches and substring_matches:
-            # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
+            # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
             config["r"] = key_rank
-            # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
+            # remove the ambiguous key from `rank_pattern` and record it as deleted
             del config["rank_pattern"][key]
+            deleted_keys.add(key)
+            # For substring matches, add them with the original rank only if they haven't been assigned already
             for mod in substring_matches:
-                # avoid overwriting if the module already has a specific rank
-                if mod not in config["rank_pattern"]:
+                if mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r

-            # update the rest of the keys with the `original_r`
+            # Update the rest of the target modules with the original rank if not already set and not deleted
             for mod in target_modules:
-                if mod != ambiguous_key and mod not in config["rank_pattern"]:
+                if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r

-    # handle alphas to deal with cases like
+    # Handle alphas to deal with cases like:
     # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
     has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
     if has_different_ranks:
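For context, here is a minimal, self-contained sketch of the behavior this commit introduces. It is not the exact diffusers implementation, and the helper name and example config below are hypothetical: when a key in `rank_pattern` both matches one entry of `target_modules` exactly and is a substring of another, its rank is promoted to the default `r`, the key is dropped from `rank_pattern`, and the dropped key is remembered in `deleted_keys` so the later passes over `substring_matches` and `target_modules` do not re-add it.

```python
# Minimal sketch of the ambiguity resolution after this commit; the function name
# and the example config are illustrative only, not part of diffusers' API.
def adjust_config_sketch(config):
    deleted_keys = set()  # keys removed from rank_pattern; never re-added
    rank_pattern = config["rank_pattern"].copy()
    target_modules = config["target_modules"]
    original_r = config["r"]

    for key in list(rank_pattern.keys()):
        key_rank = rank_pattern[key]
        exact_matches = [mod for mod in target_modules if mod == key]
        substring_matches = [mod for mod in target_modules if key in mod and mod != key]

        if exact_matches and substring_matches:
            # The ambiguous key's rank becomes the global default `r` ...
            config["r"] = key_rank
            # ... and the key itself is dropped from rank_pattern and recorded.
            del config["rank_pattern"][key]
            deleted_keys.add(key)
            # Modules that merely contain the key keep the original rank,
            # unless they already have an entry or were deleted above.
            for mod in substring_matches:
                if mod not in config["rank_pattern"] and mod not in deleted_keys:
                    config["rank_pattern"][mod] = original_r
            # Same fallback for the remaining target modules.
            for mod in target_modules:
                if mod != key and mod not in config["rank_pattern"] and mod not in deleted_keys:
                    config["rank_pattern"][mod] = original_r
    return config


# Hypothetical example: "proj_out" matches a target module exactly and is also a
# substring of "blocks.0.proj_out", so its rank (8) becomes the default `r` and
# the other modules fall back to the original r=4.
config = {
    "r": 4,
    "rank_pattern": {"proj_out": 8},
    "target_modules": ["proj_out", "blocks.0.proj_out", "to_q"],
}
print(adjust_config_sketch(config))
# -> r == 8, rank_pattern == {"blocks.0.proj_out": 4, "to_q": 4}
```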
