@@ -63,6 +63,8 @@ def __init__(
         rslora: bool = False,
         lora_plus_scale: float = 1.0,
         pissa: bool = False,
+        nola: bool = False,
+        nola_basis_num: int = 1,
         lora_use_mixer: bool = False,
         mixer_num: int = 1,
         use_mora: bool = False,
@@ -85,6 +87,8 @@ def __init__(
         # Mark the weight as unmerged
         self.merged = False
         self.pissa = pissa
+        self.nola = nola
+        self.nola_basis_num = nola_basis_num
         self.lora_use_mixer = lora_use_mixer
         self.mixer_num = mixer_num
         self.lorapro = lorapro
@@ -144,6 +148,32 @@ def __init__(
             ),
         )
         self.apply_pissa = False
+        if nola:
+            # Initialize placeholders for NOLA parameters
+            self.nola_basis_A = self.create_parameter(
+                shape=[nola_basis_num, in_features, r],
+                dtype=self._dtype,
+                is_bias=False,
+            )
+            self.nola_basis_A.stop_gradient = True
+            self.nola_basis_B = self.create_parameter(
+                shape=[nola_basis_num, r, out_features],
+                dtype=self._dtype,
+                is_bias=False,
+            )
+            self.nola_basis_B.stop_gradient = True
+            self.nola_alpha = self.create_parameter(
+                shape=[nola_basis_num],
+                dtype=self._dtype,
+                is_bias=False,
+                default_initializer=nn.initializer.Constant(value=0.0),
+            )
+            self.nola_beta = self.create_parameter(
+                shape=[nola_basis_num],
+                dtype=self._dtype,
+                is_bias=False,
+                default_initializer=nn.initializer.Constant(value=0.0),
+            )
         if use_mora or pissa:
             self.scaling = 1.0
         elif not rslora:
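Note: the hunk above only allocates the NOLA basis tensors as frozen placeholders (`stop_gradient = True`) with no initializer, so they still need to be filled with fixed random values somewhere. Below is a minimal sketch of one way to do that, assuming a hypothetical `init_nola_basis` helper and an externally chosen seed (neither appears in this diff), so the bases could be regenerated from the seed instead of being checkpointed:

```python
import paddle


def init_nola_basis(layer, seed: int = 0):
    # Hypothetical helper, not part of this PR: fill the frozen NOLA basis tensors
    # with reproducible Gaussian noise so they can be rebuilt from `seed` alone and
    # only nola_alpha / nola_beta need to be saved.
    paddle.seed(seed)
    basis_a = paddle.randn(layer.nola_basis_A.shape).astype(layer.nola_basis_A.dtype)
    basis_b = paddle.randn(layer.nola_basis_B.shape).astype(layer.nola_basis_B.dtype)
    layer.nola_basis_A.set_value(basis_a)
    layer.nola_basis_B.set_value(basis_b)
```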
@@ -179,6 +209,16 @@ def pissa_init(self, rank):
         weight = res.astype(dtype)
         self.weight.set_value(weight)
 
+    def get_nola_lora_matrices(self):
+        """Compute LoRA matrices A and B from NOLA basis and coefficients."""
+        if not self.nola:
+            return self.lora_A, self.lora_B
+        # Compute A = sum(alpha_i * A_i)
+        lora_A = paddle.einsum("k,kir->ir", self.nola_alpha, self.nola_basis_A)  # [in_features, r]
+        # Compute B = sum(beta_j * B_j)
+        lora_B = paddle.einsum("k,kro->ro", self.nola_beta, self.nola_basis_B)  # [r, out_features]
+        return lora_A, lora_B
+
     def rope_init(self):
         if self.cos is None or self.sin is None:
             inv_freq = 1.0 / (10000 ** (paddle.arange(0, self.r, 2, dtype=paddle.float32) / self.r))
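Note: the two einsum calls in `get_nola_lora_matrices` are simply coefficient-weighted sums over the basis stacks. A small self-contained check of that equivalence with toy shapes (all numbers here are made up for illustration):

```python
import paddle

# Toy dimensions: k basis matrices shaped like lora_A ([in, r]) and lora_B ([r, out]).
k, in_features, r, out_features = 4, 8, 2, 6
alpha = paddle.randn([k])
beta = paddle.randn([k])
basis_A = paddle.randn([k, in_features, r])
basis_B = paddle.randn([k, r, out_features])

# Same contractions as in get_nola_lora_matrices.
A_einsum = paddle.einsum("k,kir->ir", alpha, basis_A)
B_einsum = paddle.einsum("k,kro->ro", beta, basis_B)

# Explicit weighted sums over the basis index.
A_manual = (alpha.reshape([-1, 1, 1]) * basis_A).sum(axis=0)
B_manual = (beta.reshape([-1, 1, 1]) * basis_B).sum(axis=0)

print(bool(paddle.allclose(A_einsum, A_manual)))  # True
print(bool(paddle.allclose(B_einsum, B_manual)))  # True
```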
@@ -257,6 +297,9 @@ def get_delta_weight(self, lora_A=None, lora_B=None, lora_AB=None):
             w = w[: self.out_features]
             final_weight = w
             delta_weight = final_weight.T
+        elif self.nola:
+            lora_A, lora_B = self.get_nola_lora_matrices()
+            delta_weight = lora_A @ lora_B * self.scaling
         else:
             lora_A = lora_A if lora_A is not None else self.lora_A
             lora_B = lora_B if lora_B is not None else self.lora_B
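Note: because `get_delta_weight` falls back to the plain `lora_A @ lora_B * scaling` form for NOLA, weight merging behaves exactly like standard LoRA once A and B have been reconstructed from the bases. A standalone illustration of that merged/unmerged equivalence (shapes and scaling are arbitrary, not taken from this PR):

```python
import paddle

in_features, r, out_features, scaling = 8, 2, 6, 2.0
weight = paddle.randn([in_features, out_features])  # Paddle Linear stores weight as [in, out]
lora_A = paddle.randn([in_features, r])
lora_B = paddle.randn([r, out_features])

# Same formula the NOLA branch of get_delta_weight returns.
delta_weight = lora_A @ lora_B * scaling
merged_weight = weight + delta_weight

x = paddle.randn([3, in_features])
out_unmerged = x @ weight + (x @ lora_A @ lora_B) * scaling
out_merged = x @ merged_weight
print(bool(paddle.allclose(out_unmerged, out_merged, atol=1e-5)))  # True
```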
@@ -299,6 +342,11 @@ def forward(self, input: paddle.Tensor, *args, **kwargs):
             input = self.lora_dropout(input)
             mora_out = self._apply_mora(input)
             result += mora_out
+        elif self.nola:
+            result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
+            input = self.lora_dropout(input)
+            lora_A, lora_B = self.get_nola_lora_matrices()
+            result += (input @ lora_A @ lora_B) * self.scaling
         else:
             result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
             if self.lora_use_mixer:
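Note: in NOLA only the mixing coefficients `nola_alpha` and `nola_beta` are trained while the random bases stay frozen, which keeps the per-layer trainable parameter count at 2 * nola_basis_num scalars instead of r * (in_features + out_features) for a plain LoRA A/B pair. A back-of-the-envelope comparison with assumed sizes (not taken from this PR):

```python
# Illustrative sizes only.
in_features, out_features, r, nola_basis_num = 4096, 4096, 8, 1024

lora_trainable = r * (in_features + out_features)  # lora_A + lora_B
nola_trainable = 2 * nola_basis_num                # nola_alpha + nola_beta

print(lora_trainable)  # 65536
print(nola_trainable)  # 2048
```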
@@ -327,14 +375,16 @@ def __init__(
         use_quick_lora: bool = False,
         pissa: bool = False,
         use_mora: bool = False,
+        nola: bool = False,
+        nola_basis_num: int = 1,
         **kwargs
     ):
         RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
         if not isinstance(r, int) or r <= 0:
             raise ValueError("Lora rank r should be a positive integer")
 
-        if pissa or use_mora:
-            raise ValueError("Pissa or Mora is not supported in model parallel by now")
+        if pissa or use_mora or nola:
+            raise ValueError("Pissa, Mora or NOLA is not supported in model parallel for now")
 
         self.r = r
         self.lora_alpha = lora_alpha
@@ -593,14 +643,16 @@ def __init__(
         use_quick_lora: bool = False,
         pissa: bool = False,
         use_mora: bool = False,
+        nola: bool = False,
+        nola_basis_num: int = 1,
         **kwargs
     ):
         ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
         if not isinstance(r, int) or r <= 0:
             raise ValueError("Lora rank r should be a positive integer")
 
-        if pissa or use_mora:
-            raise ValueError("Pissa or Mora is not supported in model parallel by now")
+        if pissa or use_mora or nola:
+            raise ValueError("Pissa, Mora or NOLA is not supported in model parallel for now")
 
         self.r = r
         self.lora_alpha = lora_alpha