@@ -63,6 +63,9 @@ def _maybe_adjust_config(config):
6363 method removes the ambiguity by following what is described here:
6464 https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
6565 """
66+ # Track keys that have been explicitly removed to prevent re-adding them.
67+ deleted_keys = set ()
68+
6669 rank_pattern = config ["rank_pattern" ].copy ()
6770 target_modules = config ["target_modules" ]
6871 original_r = config ["r" ]
@@ -80,21 +83,22 @@ def _maybe_adjust_config(config):
8083 ambiguous_key = key
8184
8285 if exact_matches and substring_matches :
83- # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
86+ # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
8487 config ["r" ] = key_rank
85- # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
88+ # remove the ambiguous key from `rank_pattern` and record it as deleted
8689 del config ["rank_pattern" ][key ]
90+ deleted_keys .add (key )
91+ # For substring matches, add them with the original rank only if they haven't been assigned already
8792 for mod in substring_matches :
88- # avoid overwriting if the module already has a specific rank
89- if mod not in config ["rank_pattern" ]:
93+ if mod not in config ["rank_pattern" ] and mod not in deleted_keys :
9094 config ["rank_pattern" ][mod ] = original_r
9195
92- # update the rest of the keys with the `original_r`
96+ # Update the rest of the target modules with the original rank if not already set and not deleted
9397 for mod in target_modules :
94- if mod != ambiguous_key and mod not in config ["rank_pattern" ]:
98+ if mod != ambiguous_key and mod not in config ["rank_pattern" ] and mod not in deleted_keys :
9599 config ["rank_pattern" ][mod ] = original_r
96100
97- # handle alphas to deal with cases like
101+ # Handle alphas to deal with cases like:
98102 # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
99103 has_different_ranks = len (config ["rank_pattern" ]) > 1 and list (config ["rank_pattern" ])[0 ] != config ["r" ]
100104 if has_different_ranks :
@@ -187,6 +191,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
187191 from peft import LoraConfig , inject_adapter_in_model , set_peft_model_state_dict
188192 from peft .tuners .tuners_utils import BaseTunerLayer
189193
194+ try :
195+ from peft .utils .constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
196+ except ImportError :
197+ FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
198+
190199 cache_dir = kwargs .pop ("cache_dir" , None )
191200 force_download = kwargs .pop ("force_download" , False )
192201 proxies = kwargs .pop ("proxies" , None )
@@ -251,14 +260,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
251260 # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
252261 # Bias layers in LoRA only have a single dimension
253262 if "lora_B" in key and val .ndim > 1 :
254- rank [key ] = val .shape [1 ]
263+ # Support to handle cases where layer patterns are treated as full layer names
264+ # was added later in PEFT. So, we handle it accordingly.
265+ # TODO: when we fix the minimal PEFT version for Diffusers,
266+ # we should remove `_maybe_adjust_config()`.
267+ if FULLY_QUALIFIED_PATTERN_KEY_PREFIX :
268+ rank [f"{ FULLY_QUALIFIED_PATTERN_KEY_PREFIX } { key } " ] = val .shape [1 ]
269+ else :
270+ rank [key ] = val .shape [1 ]
255271
256272 if network_alphas is not None and len (network_alphas ) >= 1 :
257273 alpha_keys = [k for k in network_alphas .keys () if k .startswith (f"{ prefix } ." )]
258274 network_alphas = {k .replace (f"{ prefix } ." , "" ): v for k , v in network_alphas .items () if k in alpha_keys }
259275
260276 lora_config_kwargs = get_peft_kwargs (rank , network_alpha_dict = network_alphas , peft_state_dict = state_dict )
261- lora_config_kwargs = _maybe_adjust_config (lora_config_kwargs )
277+ if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX :
278+ lora_config_kwargs = _maybe_adjust_config (lora_config_kwargs )
262279
263280 if "use_dora" in lora_config_kwargs :
264281 if lora_config_kwargs ["use_dora" ]:
0 commit comments