Commit c6fab55

Support resizing ControlLoRA
1 parent: 3ce0c6e


networks/resize_lora.py (11 additions, 5 deletions)
@@ -24,6 +24,7 @@
 LORAFMT1 = ["lora_down", "lora_up"]
 LORAFMT2 = ["lora.down", "lora.up"]
 LORAFMT3 = ["lora_A", "lora_B"]
+LORAFMT4 = ["down", "up"]
 LORAFMT = LORAFMT1

 # Model save and load functions
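For orientation: each LORAFMT pair names the `down`/`up` markers used by a different checkpoint convention, and the new LORAFMT4 covers ControlLoRA-style keys that use bare `down`/`up`. A minimal sketch of what matching keys can look like; the concrete key names are illustrative assumptions, not taken from a specific checkpoint:

```python
# Illustrative key shapes only (assumed names, not from a real checkpoint).
example_keys = {
    "lora_unet_mid_block_attn_to_q.lora_down.weight": "LORAFMT1",          # sd-scripts style
    "unet.mid_block.attentions.0.processor.lora.down.weight": "LORAFMT2",  # old diffusers attn-processor style
    "base_model.model.layers.0.q_proj.lora_A.weight": "LORAFMT3",          # PEFT style
    "control_model.input_blocks.1.1.proj_in.down": "LORAFMT4",             # ControlLoRA style (this commit)
}
```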
@@ -209,13 +210,15 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         if network_alpha is None and "alpha" in key:
             network_alpha = value
         if (network_dim is None and len(value.size()) == 2
-                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key)):
+                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key or LORAFMT4[0] in key)):
             if LORAFMT1[0] in key:
                 LORAFMT = LORAFMT1
             elif LORAFMT2[0] in key:
                 LORAFMT = LORAFMT2
             elif LORAFMT3[0] in key:
                 LORAFMT = LORAFMT3
+            elif LORAFMT4[0] in key:
+                LORAFMT = LORAFMT4
             network_dim = value.size()[0]
         if network_alpha is not None and network_dim is not None:
             break
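The branch order matters here: `"down"` is a substring of both `"lora_down"` and `"lora.down"`, so LORAFMT4 has to be tested last or it would shadow the more specific formats. A minimal sketch of the selection logic in isolation (the sample keys are hypothetical):

```python
LORAFMT1, LORAFMT2 = ["lora_down", "lora_up"], ["lora.down", "lora.up"]
LORAFMT3, LORAFMT4 = ["lora_A", "lora_B"], ["down", "up"]

def detect_format(key: str):
    # Mirror of the if/elif chain above: most specific markers first,
    # bare "down" (ControlLoRA) only as the fallback.
    for fmt in (LORAFMT1, LORAFMT2, LORAFMT3, LORAFMT4):
        if fmt[0] in key:
            return fmt
    return None

print(detect_format("lora_unet_mid_block_attn_to_q.lora_down.weight"))  # ['lora_down', 'lora_up']
print(detect_format("control_model.input_blocks.1.1.proj_in.down"))     # ['down', 'up']
```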
@@ -241,14 +244,17 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
             weight_name = None
             if LORAFMT[0] in key:
                 block_down_name = key.rsplit(f".{LORAFMT[0]}", 1)[0]
-                weight_name = key.rsplit(".", 1)[-1]
+                if key.endswith(f".{LORAFMT[0]}"):
+                    weight_name = ""
+                else:
+                    weight_name = key.rsplit(f".{LORAFMT[0]}", 1)[-1]
                 lora_down_weight = value
             else:
                 continue

             # find corresponding lora_up and alpha
             block_up_name = block_down_name
-            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}." + weight_name, None)
+            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}" + weight_name, None)
             lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

             weights_loaded = lora_down_weight is not None and lora_up_weight is not None
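The new `weight_name` handling keeps whatever suffix follows the down marker (including its leading dot, e.g. `.weight`), or the empty string when the key ends at the marker itself; the old code took everything after the last dot, which for a key ending in `.down` would have produced a lookup like `...up.down` that never matches. A minimal sketch with hypothetical ControlLoRA-style keys:

```python
def up_key_for(key: str, fmt) -> str:
    # Same suffix logic as the diff above, isolated for illustration.
    block_down_name = key.rsplit(f".{fmt[0]}", 1)[0]
    if key.endswith(f".{fmt[0]}"):
        weight_name = ""                                # key ends at the marker
    else:
        weight_name = key.rsplit(f".{fmt[0]}", 1)[-1]   # keeps the leading dot, e.g. ".weight"
    return block_down_name + f".{fmt[1]}" + weight_name

FMT4 = ["down", "up"]
print(up_key_for("control_model.blocks.0.attn.down.weight", FMT4))  # control_model.blocks.0.attn.up.weight
print(up_key_for("control_model.blocks.0.attn.down", FMT4))         # control_model.blocks.0.attn.up
```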
@@ -286,8 +292,8 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
                 verbose_str += "\n"

             new_alpha = param_dict["new_alpha"]
-            o_lora_sd[block_down_name + f".{LORAFMT[0]}.weight"] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
-            o_lora_sd[block_up_name + f".{LORAFMT[1]}.weight"] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
+            o_lora_sd[block_down_name + f".{LORAFMT[0]}" + weight_name] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
+            o_lora_sd[block_up_name + f".{LORAFMT[1]}" + weight_name] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
             o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)

             block_down_name = None
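On the write side, the same `weight_name` suffix is reused so that the resized state dict mirrors the input layout: `...down.weight` keys stay `...down.weight`, and bare `...down` keys stay bare. A minimal sketch with made-up block names and a stand-in `param_dict`:

```python
import torch

FMT4 = ["down", "up"]
block, weight_name, save_dtype = "control_model.blocks.0.attn", ".weight", torch.float16
# Hypothetical per-block output of the resize step: rank-4 factors plus the recomputed alpha.
param_dict = {"down": torch.randn(4, 320), "up": torch.randn(320, 4), "new_alpha": 4.0}

o_lora_sd = {}
o_lora_sd[block + f".{FMT4[0]}" + weight_name] = param_dict[FMT4[0]].to(save_dtype).contiguous()
o_lora_sd[block + f".{FMT4[1]}" + weight_name] = param_dict[FMT4[1]].to(save_dtype).contiguous()
o_lora_sd[block + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
print(sorted(o_lora_sd.keys()))
# ['control_model.blocks.0.attn.alpha', 'control_model.blocks.0.attn.down.weight', 'control_model.blocks.0.attn.up.weight']
```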
