 
 MIN_SV = 1e-6
 
+LORA_DOWN_UP_FORMATS = [
+    ("lora_down", "lora_up"),  # sd-scripts LoRA
+    ("lora_A", "lora_B"),  # PEFT LoRA
+    ("down", "up"),  # ControlLoRA
+]
+
+
 # Model save and load functions
 
 
@@ -192,24 +199,11 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
 
 
 def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
-    network_alpha = None
-    network_dim = None
+    max_old_rank = None
+    new_alpha = None
     verbose_str = "\n"
     fro_list = []
 
-    # Extract loaded lora dim and alpha
-    for key, value in lora_sd.items():
-        if network_alpha is None and "alpha" in key:
-            network_alpha = value
-        if network_dim is None and "lora_down" in key and len(value.size()) == 2:
-            network_dim = value.size()[0]
-        if network_alpha is not None and network_dim is not None:
-            break
-    if network_alpha is None:
-        network_alpha = network_dim
-
-    scale = network_alpha / network_dim
-
     if dynamic_method:
         logger.info(
             f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
@@ -224,28 +218,47 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
 
     with torch.no_grad():
         for key, value in tqdm(lora_sd.items()):
-            weight_name = None
-            if "lora_down" in key:
-                block_down_name = key.rsplit(".lora_down", 1)[0]
-                weight_name = key.rsplit(".", 1)[-1]
-                lora_down_weight = value
-            else:
+            key_parts = key.split(".")
+            block_down_name = None
+            for _format in LORA_DOWN_UP_FORMATS:
+                # Currently we only match lora_down_name against the last two parts of the key,
+                # because ("down", "up") are generic words and may also appear in block_down_name
+                if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
+                    block_down_name = ".".join(key_parts[:-2])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
+                    weight_name = "." + key_parts[-1]
+                    break
+                if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
+                    block_down_name = ".".join(key_parts[:-1])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
+                    weight_name = ""
+                    break
+
+            if block_down_name is None:
+                # This parameter is not a lora_down weight
                 continue
 
-            # find corresponding lora_up and alpha
+            # Now weight_name is either ".weight" or ""
+            # Find the corresponding lora_up and alpha
             block_up_name = block_down_name
-            lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
+            lora_down_weight = value
+            lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
             lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
 
             weights_loaded = lora_down_weight is not None and lora_up_weight is not None
 
             if weights_loaded:
 
                 conv2d = len(lora_down_weight.size()) == 4
+                old_rank = lora_down_weight.size()[0]
+                max_old_rank = max(max_old_rank or 0, old_rank)
+
                 if lora_alpha is None:
                     scale = 1.0
                 else:
-                    scale = lora_alpha / lora_down_weight.size()[0]
+                    scale = lora_alpha / old_rank
 
                 if conv2d:
                     full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
@@ -272,9 +285,9 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
                     verbose_str += "\n"
 
                 new_alpha = param_dict["new_alpha"]
-                o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
+                o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
+                o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
+                o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
 
                 block_down_name = None
                 block_up_name = None
@@ -287,7 +300,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         print(verbose_str)
         print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
     logger.info("resizing complete")
-    return o_lora_sd, network_dim, new_alpha
+    return o_lora_sd, max_old_rank, new_alpha
 
 
 def resize(args):
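
For reference, here is a minimal standalone sketch (not part of the diff) of the key-matching scheme the new `LORA_DOWN_UP_FORMATS` table enables: given a state-dict key, resolve the block name and the down/up/weight suffixes for whichever naming convention the key uses. The helper name `match_down_key` and the example key strings are illustrative assumptions, not code or keys taken from this PR.

```python
# Minimal sketch of the format-matching logic added in this PR.
# `match_down_key` is a hypothetical helper; the example keys below are
# illustrative and may not match real checkpoints exactly.

LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"),  # sd-scripts LoRA
    ("lora_A", "lora_B"),  # PEFT LoRA
    ("down", "up"),  # ControlLoRA
]


def match_down_key(key: str):
    """Return (block_down_name, lora_down_name, lora_up_name, weight_name)
    if `key` looks like a lora_down parameter, otherwise None."""
    key_parts = key.split(".")
    for down, up in LORA_DOWN_UP_FORMATS:
        # Only the last one or two parts are checked, because "down"/"up"
        # are generic words that can also appear inside the block name.
        if len(key_parts) >= 2 and key_parts[-2] == down:
            return ".".join(key_parts[:-2]), "." + down, "." + up, "." + key_parts[-1]
        if key_parts[-1] == down:
            return ".".join(key_parts[:-1]), "." + down, "." + up, ""
    return None


if __name__ == "__main__":
    # sd-scripts style: ends with ".lora_down.weight"
    print(match_down_key("lora_unet_mid_block_attn_to_q.lora_down.weight"))
    # -> ('lora_unet_mid_block_attn_to_q', '.lora_down', '.lora_up', '.weight')

    # PEFT style: ends with ".lora_A.weight"
    print(match_down_key("base_model.model.to_q.lora_A.weight"))
    # -> ('base_model.model.to_q', '.lora_A', '.lora_B', '.weight')

    # ControlLoRA style: key itself ends with ".down" (no ".weight" suffix)
    print(match_down_key("control_model.input_blocks.1.proj.down"))
    # -> ('control_model.input_blocks.1.proj', '.down', '.up', '')
```

The corresponding lora_up key is then `block_down_name + lora_up_name + weight_name`, which is exactly how the resize loop in the diff looks it up.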