@@ -63,6 +63,9 @@ def _maybe_adjust_config(config):
6363    method removes the ambiguity by following what is described here: 
6464    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028. 
6565    """ 
66+     # Track keys that have been explicitly removed to prevent re-adding them. 
67+     deleted_keys  =  set ()
68+ 
6669    rank_pattern  =  config ["rank_pattern" ].copy ()
6770    target_modules  =  config ["target_modules" ]
6871    original_r  =  config ["r" ]
@@ -80,21 +83,22 @@ def _maybe_adjust_config(config):
8083        ambiguous_key  =  key 
8184
8285        if  exact_matches  and  substring_matches :
83-             # if ambiguous we  update the rank associated with the ambiguous key (`proj_out`, for example) 
86+             # if ambiguous,  update the rank associated with the ambiguous key (`proj_out`, for example) 
8487            config ["r" ] =  key_rank 
85-             # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead  
88+             # remove the ambiguous key from `rank_pattern` and record it as deleted  
8689            del  config ["rank_pattern" ][key ]
90+             deleted_keys .add (key )
91+             # For substring matches, add them with the original rank only if they haven't been assigned already 
8792            for  mod  in  substring_matches :
88-                 # avoid overwriting if the module already has a specific rank 
89-                 if  mod  not  in config ["rank_pattern" ]:
93+                 if  mod  not  in config ["rank_pattern" ] and  mod  not  in deleted_keys :
9094                    config ["rank_pattern" ][mod ] =  original_r 
9195
92-             # update  the rest of the keys  with the `original_r`  
96+             # Update  the rest of the target modules  with the original rank if not already set and not deleted  
9397            for  mod  in  target_modules :
94-                 if  mod  !=  ambiguous_key  and  mod  not  in config ["rank_pattern" ]:
98+                 if  mod  !=  ambiguous_key  and  mod  not  in config ["rank_pattern" ]  and   mod   not   in   deleted_keys :
9599                    config ["rank_pattern" ][mod ] =  original_r 
96100
97-     # handle  alphas to deal with cases like 
101+     # Handle  alphas to deal with cases like:  
98102    # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777 
99103    has_different_ranks  =  len (config ["rank_pattern" ]) >  1  and  list (config ["rank_pattern" ])[0 ] !=  config ["r" ]
100104    if  has_different_ranks :
@@ -187,6 +191,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
187191        from  peft  import  LoraConfig , inject_adapter_in_model , set_peft_model_state_dict 
188192        from  peft .tuners .tuners_utils  import  BaseTunerLayer 
189193
194+         try :
195+             from  peft .utils .constants  import  FULLY_QUALIFIED_PATTERN_KEY_PREFIX 
196+         except  ImportError :
197+             FULLY_QUALIFIED_PATTERN_KEY_PREFIX  =  None 
198+ 
190199        cache_dir  =  kwargs .pop ("cache_dir" , None )
191200        force_download  =  kwargs .pop ("force_download" , False )
192201        proxies  =  kwargs .pop ("proxies" , None )
@@ -251,14 +260,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
251260                # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. 
252261                # Bias layers in LoRA only have a single dimension 
253262                if  "lora_B"  in  key  and  val .ndim  >  1 :
254-                     rank [key ] =  val .shape [1 ]
263+                     # Support to handle cases where layer patterns are treated as full layer names 
264+                     # was added later in PEFT. So, we handle it accordingly. 
265+                     # TODO: when we fix the minimal PEFT version for Diffusers, 
266+                     # we should remove `_maybe_adjust_config()`. 
267+                     if  FULLY_QUALIFIED_PATTERN_KEY_PREFIX :
268+                         rank [f"{ FULLY_QUALIFIED_PATTERN_KEY_PREFIX } { key }  ] =  val .shape [1 ]
269+                     else :
270+                         rank [key ] =  val .shape [1 ]
255271
256272            if  network_alphas  is  not None  and  len (network_alphas ) >=  1 :
257273                alpha_keys  =  [k  for  k  in  network_alphas .keys () if  k .startswith (f"{ prefix }  )]
258274                network_alphas  =  {k .replace (f"{ prefix }  , "" ): v  for  k , v  in  network_alphas .items () if  k  in  alpha_keys }
259275
260276            lora_config_kwargs  =  get_peft_kwargs (rank , network_alpha_dict = network_alphas , peft_state_dict = state_dict )
261-             lora_config_kwargs  =  _maybe_adjust_config (lora_config_kwargs )
277+             if  not  FULLY_QUALIFIED_PATTERN_KEY_PREFIX :
278+                 lora_config_kwargs  =  _maybe_adjust_config (lora_config_kwargs )
262279
263280            if  "use_dora"  in  lora_config_kwargs :
264281                if  lora_config_kwargs ["use_dora" ]:
0 commit comments