@@ -24,6 +24,7 @@
 LORAFMT1 = ["lora_down", "lora_up"]
 LORAFMT2 = ["lora.down", "lora.up"]
 LORAFMT3 = ["lora_A", "lora_B"]
+LORAFMT4 = ["down", "up"]
 LORAFMT = LORAFMT1
 
 # Model save and load functions
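
For context, the four marker pairs above cover the LoRA key-naming schemes this script now recognizes. The sample keys below are illustrative assumptions (module paths are made up), not taken from this repo:

# Hypothetical key shapes per format:
#   LORAFMT1: "lora_unet_attn1_to_q.lora_down.weight"  pairs with ".lora_up.weight"  (kohya-style)
#   LORAFMT2: "unet.attn1.to_q.lora.down.weight"       pairs with ".lora.up.weight"
#   LORAFMT3: "unet.attn1.to_q.lora_A.weight"          pairs with ".lora_B.weight"   (PEFT-style)
#   LORAFMT4: "unet.attn1.to_q.down"                   pairs with ".up"              (bare suffix, no ".weight")
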
@@ -209,13 +210,15 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         if network_alpha is None and "alpha" in key:
             network_alpha = value
         if (network_dim is None and len(value.size()) == 2
-                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key)):
+                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key or LORAFMT4[0] in key)):
             if LORAFMT1[0] in key:
                 LORAFMT = LORAFMT1
             elif LORAFMT2[0] in key:
                 LORAFMT = LORAFMT2
             elif LORAFMT3[0] in key:
                 LORAFMT = LORAFMT3
+            elif LORAFMT4[0] in key:
+                LORAFMT = LORAFMT4
             network_dim = value.size()[0]
         if network_alpha is not None and network_dim is not None:
             break
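
The detection pass above can be read as a standalone helper. One thing the diff gets right: since "down" is a substring of both "lora_down" and "lora.down", LORAFMT4 must stay last in the elif chain or it would shadow the other formats. A minimal sketch of the same logic, with a hypothetical function name:

def detect_format_and_dim(lora_sd):
    # Scan for the first alpha tensor and the first 2-D down-projection
    # weight; the down weight's first dimension is the network rank.
    formats = [["lora_down", "lora_up"], ["lora.down", "lora.up"],
               ["lora_A", "lora_B"], ["down", "up"]]  # bare "down"/"up" last
    alpha = dim = fmt = None
    for key, value in lora_sd.items():
        if alpha is None and "alpha" in key:
            alpha = value
        if dim is None and value.dim() == 2:
            for cand in formats:
                if cand[0] in key:
                    fmt, dim = cand, value.size(0)
                    break
        if alpha is not None and dim is not None:
            break
    return fmt, dim, alpha
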
@@ -241,14 +244,17 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         weight_name = None
         if LORAFMT[0] in key:
             block_down_name = key.rsplit(f".{LORAFMT[0]}", 1)[0]
-            weight_name = key.rsplit(".", 1)[-1]
+            if key.endswith(f".{LORAFMT[0]}"):
+                weight_name = ""
+            else:
+                weight_name = key.rsplit(f".{LORAFMT[0]}", 1)[-1]
             lora_down_weight = value
         else:
             continue
 
         # find corresponding lora_up and alpha
         block_up_name = block_down_name
-        lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}." + weight_name, None)
+        lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}" + weight_name, None)
         lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
 
         weights_loaded = lora_down_weight is not None and lora_up_weight is not None
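
The leading dot is the subtle part of the new split: rsplit(f".{LORAFMT[0]}", 1)[-1] keeps the dot inside weight_name, so the up-key lookup above no longer inserts one (the removed line did, and also assumed a non-empty suffix). A worked example with hypothetical module paths:

key = "unet.mid_block.attn1.to_q.lora_A.weight"     # hypothetical PEFT-style key
block_down_name = key.rsplit(".lora_A", 1)[0]       # "unet.mid_block.attn1.to_q"
weight_name = key.rsplit(".lora_A", 1)[-1]          # ".weight" -- leading dot kept
up_key = block_down_name + ".lora_B" + weight_name  # "...to_q.lora_B.weight"

key = "unet.mid_block.attn1.to_q.down"              # hypothetical bare-suffix key
# key.endswith(".down") is True, so weight_name = "" and the up key becomes:
up_key = key.rsplit(".down", 1)[0] + ".up"          # "unet.mid_block.attn1.to_q.up"
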
@@ -286,8 +292,8 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
             verbose_str += "\n"
 
         new_alpha = param_dict["new_alpha"]
-        o_lora_sd[block_down_name + f".{LORAFMT[0]}.weight"] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
-        o_lora_sd[block_up_name + f".{LORAFMT[1]}.weight"] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
+        o_lora_sd[block_down_name + f".{LORAFMT[0]}" + weight_name] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
+        o_lora_sd[block_up_name + f".{LORAFMT[1]}" + weight_name] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
         o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
 
         block_down_name = None
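
One note on the unchanged alpha line: alpha is written as a 0-dim tensor, and new_alpha presumably comes out of the resizing step so that the effective scale alpha/rank survives the rank change (that computation is upstream of this hunk, so this is an assumption). A toy check with made-up numbers:

import torch

old_rank, old_alpha = 32, 16.0      # assumed input values
scale = old_alpha / old_rank        # effective LoRA scale: 0.5
new_rank = 8
new_alpha = scale * new_rank        # 4.0 keeps alpha/rank constant
saved = torch.tensor(new_alpha).to(torch.float16)   # 0-dim tensor, as in the diff
assert saved.ndim == 0 and new_alpha / new_rank == scale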