@@ -85,15 +85,15 @@ def __init__(
8585 self .lora_up = torch .nn .Linear (self .lora_dim , out_dim , bias = False )
8686
8787 if initialize == "urae" :
88- initialize_urae (org_module , self .lora_down , self .lora_up , self .lora_dim )
88+ initialize_urae (org_module , self .lora_down , self .lora_up , self .scale , self . lora_dim )
8989 # Need to store the original weights so we can get a plain LoRA out
90- self ._org_lora_up = self .lora_up .weight .data
91- self ._org_lora_down = self .lora_down .weight .data
90+ self ._org_lora_up = self .lora_up .weight .data . detach (). clone ()
91+ self ._org_lora_down = self .lora_down .weight .data . detach (). clone ()
9292 elif initialize == "pissa" :
9393 initialize_pissa (org_module , self .lora_down , self .lora_up , self .scale , self .lora_dim )
9494 # Need to store the original weights so we can get a plain LoRA out
95- self ._org_lora_up = self .lora_up .weight .data
96- self ._org_lora_down = self .lora_down .weight .data
95+ self ._org_lora_up = self .lora_up .weight .data . detach (). clone ()
96+ self ._org_lora_down = self .lora_down .weight .data . detach (). clone ()
9797 else :
9898 initialize_lora (self .lora_down , self .lora_up )
9999 else :
@@ -109,13 +109,13 @@ def __init__(
109109 if initialize == "urae" :
110110 initialize_urae (org_module , lora_down , lora_up , self .scale , self .lora_dim )
111111 # Need to store the original weights so we can get a plain LoRA out
112- self ._org_lora_up = lora_up .weight .data
113- self ._org_lora_down = lora_down .weight .data
112+ self ._org_lora_up = lora_up .weight .data . detach (). clone ()
113+ self ._org_lora_down = lora_down .weight .data . detach (). clone ()
114114 elif initialize == "pissa" :
115115 initialize_pissa (org_module , lora_down , lora_up , self .scale , self .lora_dim )
116116 # Need to store the original weights so we can get a plain LoRA out
117- self ._org_lora_up = lora_up .weight .data
118- self ._org_lora_down = lora_down .weight .data
117+ self ._org_lora_up = lora_up .weight .data . detach (). clone ()
118+ self ._org_lora_down = lora_down .weight .data . detach (). clone ()
119119 else :
120120 initialize_lora (lora_down , lora_up )
121121
@@ -1101,19 +1101,19 @@ def save_weights(self, file, dtype, metadata):
def convert_pissa_to_standard_lora (trained_up : Tensor , trained_down : Tensor , orig_up : Tensor , orig_down : Tensor , rank : int ):
    """Convert trained PiSSA factors back into plain LoRA up/down weights.

    PiSSA initializes the adapter from the principal components of the base
    weight, so the trained factors A'/B' are not directly usable as a standard
    LoRA. The learned update is the difference ΔW = A'B' - AB between the
    trained product and the original (initialization-time) product; this
    function re-factorizes ΔW into a fresh low-rank pair.

    Args:
        trained_up: Trained up-projection A' (out_dim x r).
        trained_down: Trained down-projection B' (r x in_dim).
        orig_up: Up-projection A saved at initialization time.
        orig_down: Down-projection B saved at initialization time.
        rank: Original LoRA rank r; the output uses up to 2*r singular values
            (as suggested in the PiSSA paper), clamped to what ΔW provides.

    Returns:
        Tuple (new_up, new_down) with new_up @ new_down ≈ ΔW, usable as
        standard LoRA weights.
    """
    # ΔW = A'B' - AB: the update that training actually learned.
    delta_w = (trained_up @ trained_down ) - (orig_up @ orig_down )

    # Re-factorize ΔW via SVD. Note torch.linalg.svd returns Vh (V transposed),
    # so rows of vh are the right-singular vectors.
    u , s , vh = torch .linalg .svd (delta_w , full_matrices = False )

    # Take the top 2*r singular values (per the PiSSA paper), without
    # exceeding the number of singular values ΔW actually has. Use a new
    # name instead of rebinding the `rank` parameter.
    out_rank = min (rank * 2 , s .shape [0 ])

    # Split sqrt(S) evenly between the two factors so that
    # new_up @ new_down == U[:, :k] @ diag(S[:k]) @ Vh[:k, :].
    # Broadcasting scales columns of U / rows of Vh directly — numerically
    # identical to the diag-matmul form but without building k x k matrices.
    sqrt_s = torch .sqrt (s [:out_rank ])
    new_up = u [:, :out_rank ] * sqrt_s
    new_down = sqrt_s [:, None ] * vh [:out_rank , :]

    # These matrices can now be used as standard LoRA weights
    return new_up , new_down
11191119
0 commit comments