5454    "SanaTransformer2DModel" : lambda  model_cls , weights : weights ,
5555    "Lumina2Transformer2DModel" : lambda  model_cls , weights : weights ,
5656}
57- _NO_CONFIG_UPDATE_KEYS  =  ["to_k" , "to_q" , "to_v" ]
5857
5958
6059def  _maybe_adjust_config (config ):
@@ -64,40 +63,38 @@ def _maybe_adjust_config(config):
6463    method removes the ambiguity by following what is described here: 
6564    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028. 
6665    """ 
66+     # Track keys that have been explicitly removed to prevent re-adding them. 
67+     deleted_keys  =  set ()
68+ 
6769    rank_pattern  =  config ["rank_pattern" ].copy ()
6870    target_modules  =  config ["target_modules" ]
6971    original_r  =  config ["r" ]
7072
7173    for  key  in  list (rank_pattern .keys ()):
72-         if  any (prefix  in  key  for  prefix  in  _NO_CONFIG_UPDATE_KEYS ):
73-             continue 
7474        key_rank  =  rank_pattern [key ]
7575
7676        # try to detect ambiguity 
77-         # `target_modules` can also be a str, in which case this loop would loop 
78-         # over the chars of the str. The technically correct way to match LoRA keys 
79-         # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). 
80-         # But this cuts it for now. 
8177        exact_matches  =  [mod  for  mod  in  target_modules  if  mod  ==  key ]
8278        substring_matches  =  [mod  for  mod  in  target_modules  if  key  in  mod  and  mod  !=  key ]
8379        ambiguous_key  =  key 
8480
8581        if  exact_matches  and  substring_matches :
86-             # if ambiguous we  update the rank associated with the ambiguous key (`proj_out`, for example) 
82+             # if ambiguous,  update the rank associated with the ambiguous key (`proj_out`, for example) 
8783            config ["r" ] =  key_rank 
88-             # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead  
84+             # remove the ambiguous key from `rank_pattern` and record it as deleted  
8985            del  config ["rank_pattern" ][key ]
86+             deleted_keys .add (key )
87+             # For substring matches, add them with the original rank only if they haven't been assigned already 
9088            for  mod  in  substring_matches :
91-                 # avoid overwriting if the module already has a specific rank 
92-                 if  mod  not  in config ["rank_pattern" ]:
89+                 if  mod  not  in config ["rank_pattern" ] and  mod  not  in deleted_keys :
9390                    config ["rank_pattern" ][mod ] =  original_r 
9491
95-             # update  the rest of the keys  with the `original_r`  
92+             # Update  the rest of the target modules  with the original rank if not already set and not deleted  
9693            for  mod  in  target_modules :
97-                 if  mod  !=  ambiguous_key  and  mod  not  in config ["rank_pattern" ]:
94+                 if  mod  !=  ambiguous_key  and  mod  not  in config ["rank_pattern" ]  and   mod   not   in   deleted_keys :
9895                    config ["rank_pattern" ][mod ] =  original_r 
9996
100-     # handle  alphas to deal with cases like 
97+     # Handle  alphas to deal with cases like:  
10198    # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777 
10299    has_different_ranks  =  len (config ["rank_pattern" ]) >  1  and  list (config ["rank_pattern" ])[0 ] !=  config ["r" ]
103100    if  has_different_ranks :
0 commit comments