Skip to content

Commit 0bad5ae

Browse files
committed
Detach and clone original LoRA weights before training
1 parent 3356314 commit 0bad5ae

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

networks/lora_flux.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -85,15 +85,15 @@ def __init__(
8585
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
8686

8787
if initialize == "urae":
88-
initialize_urae(org_module, self.lora_down, self.lora_up, self.lora_dim)
88+
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
8989
# Need to store the original weights so we can get a plain LoRA out
90-
self._org_lora_up = self.lora_up.weight.data
91-
self._org_lora_down = self.lora_down.weight.data
90+
self._org_lora_up = self.lora_up.weight.data.detach().clone()
91+
self._org_lora_down = self.lora_down.weight.data.detach().clone()
9292
elif initialize == "pissa":
9393
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
9494
# Need to store the original weights so we can get a plain LoRA out
95-
self._org_lora_up = self.lora_up.weight.data
96-
self._org_lora_down = self.lora_down.weight.data
95+
self._org_lora_up = self.lora_up.weight.data.detach().clone()
96+
self._org_lora_down = self.lora_down.weight.data.detach().clone()
9797
else:
9898
initialize_lora(self.lora_down, self.lora_up)
9999
else:
@@ -109,13 +109,13 @@ def __init__(
109109
if initialize == "urae":
110110
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
111111
# Need to store the original weights so we can get a plain LoRA out
112-
self._org_lora_up = lora_up.weight.data
113-
self._org_lora_down = lora_down.weight.data
112+
self._org_lora_up = lora_up.weight.data.detach().clone()
113+
self._org_lora_down = lora_down.weight.data.detach().clone()
114114
elif initialize == "pissa":
115115
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
116116
# Need to store the original weights so we can get a plain LoRA out
117-
self._org_lora_up = lora_up.weight.data
118-
self._org_lora_down = lora_down.weight.data
117+
self._org_lora_up = lora_up.weight.data.detach().clone()
118+
self._org_lora_down = lora_down.weight.data.detach().clone()
119119
else:
120120
initialize_lora(lora_down, lora_up)
121121

@@ -1101,19 +1101,19 @@ def save_weights(self, file, dtype, metadata):
11011101
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
11021102
# Calculate ΔW = A'B' - AB
11031103
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
1104-
1104+
11051105
# We need to create new low-rank matrices that represent this delta
11061106
# One approach is to do SVD on delta_w
11071107
U, S, V = torch.linalg.svd(delta_w, full_matrices=False)
1108-
1108+
11091109
# Take the top 2*r singular values (as suggested in the paper)
11101110
rank = rank * 2
11111111
rank = min(rank, len(S)) # Make sure we don't exceed available singular values
1112-
1112+
11131113
# Create new LoRA matrices
11141114
new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
11151115
new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
1116-
1116+
11171117
# These matrices can now be used as standard LoRA weights
11181118
return new_up, new_down
11191119

0 commit comments

Comments (0)