@@ -47,6 +47,7 @@ def __init__(
         use_quick_lora: bool = False,
         rslora: bool = False,
         lora_plus_scale: float = 1.0,
+        pissa: bool = False,
         **kwargs
     ):
         nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -62,6 +63,7 @@ def __init__(
         # Mark the weight as unmerged
         self.merged = False
         self.merge_weights = merge_weights
+        self.pissa = pissa

         # Actual trainable parameters
         self.lora_A = self.create_parameter(
@@ -79,9 +81,12 @@ def __init__(
                 learning_rate=lora_plus_scale,
             ),
         )
+        self.apply_pissa = False

-        if not rslora:
+        if not rslora and not pissa:
             self.scaling = self.lora_alpha / self.r
+        elif pissa:
+            self.scaling = 1.0
         else:
             self.scaling = self.lora_alpha / math.sqrt(self.r)

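For reference, a minimal standalone sketch of the scaling rule introduced above (the helper name `lora_scaling` is illustrative, not part of the patch): PiSSA keeps `scaling = 1.0` because the adapter matrices are seeded from the weight's own principal singular components, so the low-rank decomposition is already exact at initialization and needs no `alpha`-based rescaling.

```python
import math

def lora_scaling(lora_alpha: int, r: int, rslora: bool = False, pissa: bool = False) -> float:
    """Mirrors the branch above: PiSSA -> 1.0, rsLoRA -> alpha/sqrt(r), plain LoRA -> alpha/r."""
    if pissa:
        # PiSSA: lora_A/lora_B already carry the top-r singular components of the weight.
        return 1.0
    if rslora:
        # rsLoRA: rank-stabilized scaling.
        return lora_alpha / math.sqrt(r)
    return lora_alpha / r

print(lora_scaling(32, 8))               # 4.0    (plain LoRA)
print(lora_scaling(32, 8, rslora=True))  # ~11.31 (rsLoRA)
print(lora_scaling(32, 8, pissa=True))   # 1.0    (PiSSA)
```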
@@ -93,6 +98,25 @@ def __init__(
     def use_quick_lora(self):
         return self._use_quick_lora and self.training and not self.merged

+    def pissa_init(self, rank):
+        weight = self.weight
+        dtype = weight.dtype
+        if dtype != paddle.float32:
+            weight = weight.astype(paddle.float32)
+
+        U, S, Vh = paddle.linalg.svd(weight.data, full_matrices=False)
+        Ur = U[:, :rank]
+        Sr = S[:rank]
+        Vhr = Vh[:rank]
+
+        lora_A = Ur @ paddle.diag(paddle.sqrt(Sr))
+        lora_B = paddle.diag(paddle.sqrt(Sr)) @ Vhr
+        self.lora_A.set_value(lora_A.astype(dtype))
+        self.lora_B.set_value(lora_B.astype(dtype))
+        res = weight.data - lora_A @ lora_B
+        weight = res.astype(dtype)
+        self.weight.set_value(weight)
+
     def train(self):
         super().train()
         if self.merge_weights and self.merged:
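The `pissa_init` added above splits the frozen base weight into a residual plus the rank-`rank` product of its leading singular components. A self-contained sketch of the same decomposition on plain Paddle tensors (shapes and variable names here are made up for illustration), checking that `residual + lora_A @ lora_B` reproduces the original weight:

```python
import paddle

paddle.seed(0)
rank = 8
W = paddle.randn([64, 32], dtype="float32")  # stand-in for self.weight ([in_features, out_features])

# Top-rank singular components of W
U, S, Vh = paddle.linalg.svd(W, full_matrices=False)
Ur, Sr, Vhr = U[:, :rank], S[:rank], Vh[:rank]

# PiSSA split: the adapter takes the principal part, the base weight keeps the residual
lora_A = Ur @ paddle.diag(paddle.sqrt(Sr))    # [in_features, rank]
lora_B = paddle.diag(paddle.sqrt(Sr)) @ Vhr   # [rank, out_features]
residual = W - lora_A @ lora_B                # what self.weight is overwritten with

# The sum reproduces W up to float error, so the layer's output is unchanged right after init.
print(float(paddle.abs(residual + lora_A @ lora_B - W).max()))  # small, ~1e-6
```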
@@ -110,6 +134,10 @@ def eval(self):
             self.merged = True

     def forward(self, input: paddle.Tensor, *args, **kwargs):
+        if not self.apply_pissa and self.pissa:
+            self.pissa_init(self.r)
+            self.apply_pissa = True
+
         if self.use_quick_lora:
             # Use the quick lora implementation
             result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling)
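Because the initialization is lazy, the SVD runs once, on the first forward call. A rough usage sketch under the assumption that the patched class is importable as shown (the import path `paddlenlp.peft.lora.lora_layers` is an assumption; adjust to your install):

```python
import paddle
from paddlenlp.peft.lora.lora_layers import LoRALinear  # import path is an assumption

layer = LoRALinear(in_features=64, out_features=32, r=8, lora_alpha=8, pissa=True)
print(layer.apply_pissa)   # False: the SVD has not run yet

x = paddle.randn([4, 64])
out = layer(x)             # first forward triggers pissa_init(self.r) exactly once

print(layer.apply_pissa)   # True: self.weight holds the residual, lora_A/lora_B the principal part
```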
@@ -136,11 +164,16 @@ def __init__(
         lora_plus_scale: float = 1.0,
         merge_weights: bool = True,
         use_quick_lora: bool = False,
+        pissa: bool = False,
         **kwargs
     ):
         RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
         if not isinstance(r, int) or r <= 0:
             raise ValueError("Lora rank r should be a positive integer")
+
+        if pissa:
+            raise ValueError("Pissa is not supported in model parallel by now")
+
         self.r = r
         self.lora_alpha = lora_alpha
         # Optional dropout
@@ -278,11 +311,16 @@ def __init__(
         merge_weights: bool = True,
         lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
         use_quick_lora: bool = False,
+        pissa: bool = False,
         **kwargs
     ):
         ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
         if not isinstance(r, int) or r <= 0:
             raise ValueError("Lora rank r should be a positive integer")
+
+        if pissa:
+            raise ValueError("Pissa is not supported in model parallel by now")
+
         self.r = r
         self.lora_alpha = lora_alpha
         # Optional dropout