
Commit 2857f21

Merge pull request #2175 from woct0rdho/resize-control-lora
Resize ControlLoRA
2 parents a21b6a9 + 3ad71e1

File tree: 1 file changed (+41, -28 lines)

networks/resize_lora.py

Lines changed: 41 additions & 28 deletions

@@ -20,6 +20,13 @@
 
 MIN_SV = 1e-6
 
+LORA_DOWN_UP_FORMATS = [
+    ("lora_down", "lora_up"),  # sd-scripts LoRA
+    ("lora_A", "lora_B"),  # PEFT LoRA
+    ("down", "up"),  # ControlLoRA
+]
+
+
 # Model save and load functions
 
 
@@ -192,24 +199,11 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
 
 
 def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
-    network_alpha = None
-    network_dim = None
+    max_old_rank = None
+    new_alpha = None
     verbose_str = "\n"
     fro_list = []
 
-    # Extract loaded lora dim and alpha
-    for key, value in lora_sd.items():
-        if network_alpha is None and "alpha" in key:
-            network_alpha = value
-        if network_dim is None and "lora_down" in key and len(value.size()) == 2:
-            network_dim = value.size()[0]
-        if network_alpha is not None and network_dim is not None:
-            break
-    if network_alpha is None:
-        network_alpha = network_dim
-
-    scale = network_alpha / network_dim
-
     if dynamic_method:
         logger.info(
             f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
@@ -224,28 +218,47 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
 
     with torch.no_grad():
         for key, value in tqdm(lora_sd.items()):
-            weight_name = None
-            if "lora_down" in key:
-                block_down_name = key.rsplit(".lora_down", 1)[0]
-                weight_name = key.rsplit(".", 1)[-1]
-                lora_down_weight = value
-            else:
+            key_parts = key.split(".")
+            block_down_name = None
+            for _format in LORA_DOWN_UP_FORMATS:
+                # Currently we only match lora_down_name in the last two parts of key
+                # because ("down", "up") are general words and may appear in block_down_name
+                if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
+                    block_down_name = ".".join(key_parts[:-2])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
+                    weight_name = "." + key_parts[-1]
+                    break
+                if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
+                    block_down_name = ".".join(key_parts[:-1])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
+                    weight_name = ""
+                    break
+
+            if block_down_name is None:
+                # This parameter is not lora_down
                 continue
 
-            # find corresponding lora_up and alpha
+            # Now weight_name can be ".weight" or ""
+            # Find corresponding lora_up and alpha
             block_up_name = block_down_name
-            lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
+            lora_down_weight = value
+            lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
            lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
 
             weights_loaded = lora_down_weight is not None and lora_up_weight is not None
 
             if weights_loaded:
 
                 conv2d = len(lora_down_weight.size()) == 4
+                old_rank = lora_down_weight.size()[0]
+                max_old_rank = max(max_old_rank or 0, old_rank)
+
                 if lora_alpha is None:
                     scale = 1.0
                 else:
-                    scale = lora_alpha / lora_down_weight.size()[0]
+                    scale = lora_alpha / old_rank
 
                 if conv2d:
                     full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
@@ -272,9 +285,9 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
                     verbose_str += "\n"
 
                 new_alpha = param_dict["new_alpha"]
-                o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
+                o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
+                o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
+                o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
 
                 block_down_name = None
                 block_up_name = None
@@ -287,7 +300,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
         print(verbose_str)
         print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
     logger.info("resizing complete")
-    return o_lora_sd, network_dim, new_alpha
+    return o_lora_sd, max_old_rank, new_alpha
 
 
 def resize(args):
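
For reference, below is a small, self-contained sketch of how the key matching introduced by this commit behaves. The helper function match_lora_down and the example keys are hypothetical illustrations (not taken from a real checkpoint); they only mirror the suffix handling of the loop in resize_lora_model above.

LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"),  # sd-scripts LoRA
    ("lora_A", "lora_B"),  # PEFT LoRA
    ("down", "up"),  # ControlLoRA
]


def match_lora_down(key):
    """Return (block_down_name, lora_down_name, lora_up_name, weight_name), or None if key is not a lora_down parameter."""
    key_parts = key.split(".")
    for down, up in LORA_DOWN_UP_FORMATS:
        # "<block>.<down>.weight": the down-name sits in the second-to-last part
        if len(key_parts) >= 2 and down == key_parts[-2]:
            return ".".join(key_parts[:-2]), "." + down, "." + up, "." + key_parts[-1]
        # "<block>.<down>": the down-name is the last part, with no ".weight" suffix
        if len(key_parts) >= 1 and down == key_parts[-1]:
            return ".".join(key_parts[:-1]), "." + down, "." + up, ""
    return None


# Hypothetical keys, one per supported format:
print(match_lora_down("lora_unet_mid_block.lora_down.weight"))
# ('lora_unet_mid_block', '.lora_down', '.lora_up', '.weight')
print(match_lora_down("transformer.blocks.0.attn.lora_A.weight"))
# ('transformer.blocks.0.attn', '.lora_A', '.lora_B', '.weight')
print(match_lora_down("control_model.blocks.0.attn.down"))
# ('control_model.blocks.0.attn', '.down', '.up', '')
print(match_lora_down("lora_unet_mid_block.alpha"))
# None -> this parameter is skipped by the resize loop

Because the matched suffixes are reused when the resized tensors are written back (block_down_name + lora_down_name + weight_name, and so on), the output state dict keeps the same naming scheme as the input, so a resized ControlLoRA still loads as a ControlLoRA.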
