@@ -20,12 +20,12 @@
 
 MIN_SV = 1e-6
 
-# Tune layers to various trainer formats.
-LORAFMT1 = ["lora_down", "lora_up"]
-LORAFMT2 = ["lora.down", "lora.up"]
-LORAFMT3 = ["lora_A", "lora_B"]
-LORAFMT4 = ["down", "up"]
-LORAFMT = LORAFMT1
+LORA_DOWN_UP_FORMATS = [
+    ("lora_down", "lora_up"),  # sd-scripts LoRA
+    ("lora_A", "lora_B"),  # PEFT LoRA
+    ("down", "up"),  # ControlLoRA
+]
+
 
 # Model save and load functions
 
@@ -97,8 +97,8 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
     U = U @ torch.diag(S)
     Vh = Vh[:lora_rank, :]
 
-    param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
-    param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
     del U, S, Vh, weight
     return param_dict
 
@@ -116,8 +116,8 @@ def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, sca
     U = U @ torch.diag(S)
     Vh = Vh[:lora_rank, :]
 
-    param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size).cpu()
-    param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank).cpu()
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
     del U, S, Vh, weight
     return param_dict
 
@@ -199,34 +199,11 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
 
 
 def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
-    global LORAFMT
-    network_alpha = None
-    network_dim = None
+    max_old_rank = None
+    new_alpha = None
     verbose_str = "\n"
     fro_list = []
 
-    # Extract loaded lora dim and alpha
-    for key, value in lora_sd.items():
-        if network_alpha is None and "alpha" in key:
-            network_alpha = value
-        if (network_dim is None and len(value.size()) == 2
-                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key or LORAFMT4[0] in key)):
-            if LORAFMT1[0] in key:
-                LORAFMT = LORAFMT1
-            elif LORAFMT2[0] in key:
-                LORAFMT = LORAFMT2
-            elif LORAFMT3[0] in key:
-                LORAFMT = LORAFMT3
-            elif LORAFMT4[0] in key:
-                LORAFMT = LORAFMT4
-            network_dim = value.size()[0]
-        if network_alpha is not None and network_dim is not None:
-            break
-        if network_alpha is None:
-            network_alpha = network_dim
-
-    scale = network_alpha / network_dim
-
     if dynamic_method:
         logger.info(
             f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
@@ -241,31 +218,47 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
 
     with torch.no_grad():
         for key, value in tqdm(lora_sd.items()):
-            weight_name = None
-            if LORAFMT[0] in key:
-                block_down_name = key.rsplit(f".{LORAFMT[0]}", 1)[0]
-                if key.endswith(f".{LORAFMT[0]}"):
+            key_parts = key.split(".")
+            block_down_name = None
+            for _format in LORA_DOWN_UP_FORMATS:
+                # Currently we only match lora_down_name in the last two parts of key
+                # because ("down", "up") are general words and may appear in block_down_name
+                if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
+                    block_down_name = ".".join(key_parts[:-2])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
+                    weight_name = "." + key_parts[-1]
+                    break
+                if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
+                    block_down_name = ".".join(key_parts[:-1])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
                     weight_name = ""
-                else:
-                    weight_name = key.rsplit(f".{LORAFMT[0]}", 1)[-1]
-                lora_down_weight = value
-            else:
+                    break
+
+            if block_down_name is None:
+                # This parameter is not lora_down
                 continue
 
-            # find corresponding lora_up and alpha
+            # Now weight_name can be ".weight" or ""
+            # Find corresponding lora_up and alpha
             block_up_name = block_down_name
-            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}" + weight_name, None)
+            lora_down_weight = value
+            lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
             lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
 
             weights_loaded = lora_down_weight is not None and lora_up_weight is not None
 
             if weights_loaded:
 
                 conv2d = len(lora_down_weight.size()) == 4
+                old_rank = lora_down_weight.size()[0]
+                max_old_rank = max(max_old_rank or 0, old_rank)
+
                 if lora_alpha is None:
                     scale = 1.0
                 else:
-                    scale = lora_alpha / lora_down_weight.size()[0]
+                    scale = lora_alpha / old_rank
 
                 if conv2d:
                     full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
@@ -292,9 +285,9 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
                     verbose_str += "\n"
 
                 new_alpha = param_dict["new_alpha"]
-                o_lora_sd[block_down_name + f".{LORAFMT[0]}" + weight_name] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + f".{LORAFMT[1]}" + weight_name] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
+                o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
+                o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
+                o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
 
                 block_down_name = None
                 block_up_name = None
@@ -307,7 +300,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         print(verbose_str)
         print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
     logger.info("resizing complete")
-    return o_lora_sd, network_dim, new_alpha
+    return o_lora_sd, max_old_rank, new_alpha
 
 
 def resize(args):
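
For readers unfamiliar with the key layouts involved, here is a minimal standalone sketch of the matching logic this commit introduces. It mirrors the `LORA_DOWN_UP_FORMATS` loop added in `resize_lora_model` and shows how each supported suffix pair resolves a state-dict key into its block name, down/up suffixes, and optional `.weight` tail. The example keys are hypothetical and only illustrate the three formats; they are not taken from any particular checkpoint.

```python
# Sketch only: mirrors the matching loop added in resize_lora_model above.
# The example keys below are made up for illustration.
LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"),  # sd-scripts LoRA
    ("lora_A", "lora_B"),      # PEFT LoRA
    ("down", "up"),            # ControlLoRA
]


def match_lora_down(key: str):
    """Return (block_down_name, lora_down_name, lora_up_name, weight_name) or None."""
    key_parts = key.split(".")
    for down, up in LORA_DOWN_UP_FORMATS:
        # Only the last one or two parts are checked, because "down"/"up"
        # may also occur inside the block name itself.
        if len(key_parts) >= 2 and down == key_parts[-2]:
            return ".".join(key_parts[:-2]), "." + down, "." + up, "." + key_parts[-1]
        if down == key_parts[-1]:
            return ".".join(key_parts[:-1]), "." + down, "." + up, ""
    return None  # not a lora_down parameter (e.g. an alpha or lora_up key)


if __name__ == "__main__":
    for key in [
        "lora_unet_mid_block_attn1.lora_down.weight",  # sd-scripts style
        "unet.mid_block.attentions.0.lora_A.weight",   # PEFT style
        "control_model.input_blocks.1.down",           # ControlLoRA style
        "lora_unet_mid_block_attn1.alpha",             # skipped: not a down weight
    ]:
        print(f"{key} -> {match_lora_down(key)}")
```

Restricting the comparison to the tail of the key is what lets the generic `("down", "up")` pair coexist with block names that themselves contain "down" or "up".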