Commit 3ad71e1
Refactor to avoid mutable global variable
1 parent c6fab55 commit 3ad71e1

networks/resize_lora.py

Lines changed: 43 additions & 50 deletions
@@ -20,12 +20,12 @@

 MIN_SV = 1e-6

-# Tune layers to various trainer formats.
-LORAFMT1 = ["lora_down", "lora_up"]
-LORAFMT2 = ["lora.down", "lora.up"]
-LORAFMT3 = ["lora_A", "lora_B"]
-LORAFMT4 = ["down", "up"]
-LORAFMT = LORAFMT1
+LORA_DOWN_UP_FORMATS = [
+    ("lora_down", "lora_up"),  # sd-scripts LoRA
+    ("lora_A", "lora_B"),  # PEFT LoRA
+    ("down", "up"),  # ControlLoRA
+]
+

 # Model save and load functions

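For orientation, the three tuples above correspond to key naming schemes along the lines of the following. These keys are made-up illustrations, not entries from a real checkpoint:

# Hypothetical state-dict keys and the tuple that would match each:
#   "unet_blocks_0_attn1.lora_down.weight"  -> ("lora_down", "lora_up")  sd-scripts LoRA
#   "layers_0_q_proj.lora_A.weight"         -> ("lora_A", "lora_B")      PEFT LoRA
#   "control_model_input_blocks_1.down"     -> ("down", "up")            ControlLoRA, no ".weight" suffix
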
@@ -97,8 +97,8 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
     U = U @ torch.diag(S)
     Vh = Vh[:lora_rank, :]

-    param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
-    param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
     del U, S, Vh, weight
     return param_dict

@@ -116,8 +116,8 @@ def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, sca
     U = U @ torch.diag(S)
     Vh = Vh[:lora_rank, :]

-    param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size).cpu()
-    param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank).cpu()
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
     del U, S, Vh, weight
     return param_dict

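The two extract_* hunks above only rename the output keys, but for readers following the code, here is a minimal self-contained sketch of the SVD-based rank reduction those functions are built around. The helper name is invented, and the real functions additionally handle dynamic rank selection, scaling, and the conv reshape shown above:

import torch

def lowrank_factors(weight: torch.Tensor, lora_rank: int) -> dict:
    # Factor a (out_features, in_features) matrix into lora_up @ lora_down of rank lora_rank.
    U, S, Vh = torch.linalg.svd(weight.float())
    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)  # fold the singular values into the "up" factor
    Vh = Vh[:lora_rank, :]
    return {"lora_down": Vh.contiguous().cpu(), "lora_up": U.contiguous().cpu()}

# Example: rank-4 approximation of a random 64x32 weight.
factors = lowrank_factors(torch.randn(64, 32), 4)
delta_w = factors["lora_up"] @ factors["lora_down"]  # (64, 32) best rank-4 approximation
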
@@ -199,34 +199,11 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):


 def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
-    global LORAFMT
-    network_alpha = None
-    network_dim = None
+    max_old_rank = None
+    new_alpha = None
     verbose_str = "\n"
     fro_list = []

-    # Extract loaded lora dim and alpha
-    for key, value in lora_sd.items():
-        if network_alpha is None and "alpha" in key:
-            network_alpha = value
-        if (network_dim is None and len(value.size()) == 2
-                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key or LORAFMT4[0] in key)):
-            if LORAFMT1[0] in key:
-                LORAFMT = LORAFMT1
-            elif LORAFMT2[0] in key:
-                LORAFMT = LORAFMT2
-            elif LORAFMT3[0] in key:
-                LORAFMT = LORAFMT3
-            elif LORAFMT4[0] in key:
-                LORAFMT = LORAFMT4
-            network_dim = value.size()[0]
-        if network_alpha is not None and network_dim is not None:
-            break
-    if network_alpha is None:
-        network_alpha = network_dim
-
-    scale = network_alpha / network_dim
-
     if dynamic_method:
         logger.info(
             f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
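
The deleted block above derived one network_dim and network_alpha for the whole file and a single global scale from them. After the refactor, each module's scale is computed from its own alpha and rank inside the main loop (next hunk). As a reminder of what that scale does, a minimal sketch with made-up shapes:

import torch

# Illustrative module: a rank-16 LoRA on a 320 -> 640 linear layer with alpha = 8.
rank, alpha = 16, 8.0
lora_down = torch.randn(rank, 320)  # (rank, in_features)
lora_up = torch.randn(640, rank)    # (out_features, rank)

scale = alpha / rank                     # 0.5, computed per module as in the loop below
delta_w = (lora_up @ lora_down) * scale  # effective weight delta this module contributes
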
@@ -241,31 +218,47 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna

     with torch.no_grad():
         for key, value in tqdm(lora_sd.items()):
-            weight_name = None
-            if LORAFMT[0] in key:
-                block_down_name = key.rsplit(f".{LORAFMT[0]}", 1)[0]
-                if key.endswith(f".{LORAFMT[0]}"):
+            key_parts = key.split(".")
+            block_down_name = None
+            for _format in LORA_DOWN_UP_FORMATS:
+                # Currently we only match lora_down_name in the last two parts of key
+                # because ("down", "up") are general words and may appear in block_down_name
+                if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
+                    block_down_name = ".".join(key_parts[:-2])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
+                    weight_name = "." + key_parts[-1]
+                    break
+                if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
+                    block_down_name = ".".join(key_parts[:-1])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
                     weight_name = ""
-                else:
-                    weight_name = key.rsplit(f".{LORAFMT[0]}", 1)[-1]
-                lora_down_weight = value
-            else:
+                    break
+
+            if block_down_name is None:
+                # This parameter is not lora_down
                 continue

-            # find corresponding lora_up and alpha
+            # Now weight_name can be ".weight" or ""
+            # Find corresponding lora_up and alpha
             block_up_name = block_down_name
-            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}" + weight_name, None)
+            lora_down_weight = value
+            lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
             lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

             weights_loaded = lora_down_weight is not None and lora_up_weight is not None

             if weights_loaded:

                 conv2d = len(lora_down_weight.size()) == 4
+                old_rank = lora_down_weight.size()[0]
+                max_old_rank = max(max_old_rank or 0, old_rank)
+
                 if lora_alpha is None:
                     scale = 1.0
                 else:
-                    scale = lora_alpha / lora_down_weight.size()[0]
+                    scale = lora_alpha / old_rank

                 if conv2d:
                     full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
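
To see the new matching rule in isolation, here is a hedged standalone sketch of the same idea; the function name and the example keys are invented for illustration. Like the loop above, it only inspects the last one or two dot-separated parts of a key, so generic words such as "down" or "up" inside the block name cannot cause a false match:

LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"),  # sd-scripts LoRA
    ("lora_A", "lora_B"),      # PEFT LoRA
    ("down", "up"),            # ControlLoRA
]

def match_lora_down_key(key):
    # Return (block_down_name, lora_down_name, lora_up_name, weight_name),
    # or None if the key is not a lora_down parameter.
    key_parts = key.split(".")
    for down, up in LORA_DOWN_UP_FORMATS:
        if len(key_parts) >= 2 and key_parts[-2] == down:
            return ".".join(key_parts[:-2]), "." + down, "." + up, "." + key_parts[-1]
        if key_parts[-1] == down:
            return ".".join(key_parts[:-1]), "." + down, "." + up, ""
    return None

# Made-up keys, one per supported format:
assert match_lora_down_key("unet_blocks_0_attn1.lora_down.weight")[3] == ".weight"
assert match_lora_down_key("layers_0_q_proj.lora_A.weight")[1] == ".lora_A"
assert match_lora_down_key("control_model_input_blocks_1.down")[3] == ""
assert match_lora_down_key("unet_blocks_0_attn1.lora_up.weight") is None
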
@@ -292,9 +285,9 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
                     verbose_str += "\n"

                 new_alpha = param_dict["new_alpha"]
-                o_lora_sd[block_down_name + f".{LORAFMT[0]}" + weight_name] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + f".{LORAFMT[1]}" + weight_name] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
+                o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
+                o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
+                o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)

                 block_down_name = None
                 block_up_name = None
@@ -307,7 +300,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         print(verbose_str)
         print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
     logger.info("resizing complete")
-    return o_lora_sd, network_dim, new_alpha
+    return o_lora_sd, max_old_rank, new_alpha


 def resize(args):
