Commit e050fd8

[PEFT]: nola (#10747)

* 'nola commit'
* Update test_nola.py
* solve conflicts
1 parent e298ba4 commit e050fd8

File tree

8 files changed: +412 additions, -4 deletions


llm/run_finetune.py

Lines changed: 2 additions & 0 deletions

@@ -580,7 +580,9 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
            use_quick_lora=model_args.use_quick_lora,
            lora_use_mixer=model_args.lora_use_mixer,
            use_mora=model_args.use_mora,
+            nola=model_args.nola,
+            nola_basis_num=model_args.nola_basis_num,
            mixer_num=model_args.mixer_num,
            lorapro=model_args.lorapro,
        )
        if model_args.lorapro:

paddlenlp/peft/lora/lora_config.py

Lines changed: 2 additions & 0 deletions

@@ -77,6 +77,8 @@ class LoRAConfig:
    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
    pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})
    loraga: bool = field(default=False, metadata={"help": "Whether to use LoRA-GA"})
+    nola: bool = field(default=False, metadata={"help": "Whether to use NoLA: https://arxiv.org/pdf/2310.02556"})
+    nola_basis_num: int = field(default=1, metadata={"help": "Number of basis matrices per side when using NoLA"})
    use_mora: bool = field(
        default=False, metadata={"help": "Whether to use MoRA: https://arxiv.org/pdf/2405.12130.pdf"}
    )
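For context, a minimal sketch of enabling the new fields through the existing LoRAConfig/LoRAModel API; the tiny test model and the target_modules patterns below are placeholders for illustration, not taken from this commit:

    # Hedged sketch: turning on NOLA via the new LoRAConfig fields.
    from paddlenlp.peft import LoRAConfig, LoRAModel
    from paddlenlp.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("__internal_testing__/tiny-random-llama")
    lora_config = LoRAConfig(
        target_modules=[".*q_proj.*", ".*v_proj.*"],  # assumed module-name patterns
        r=8,
        lora_alpha=16,
        nola=True,          # new flag added by this commit
        nola_basis_num=3,   # number of frozen basis matrices per side
    )
    model = LoRAModel(model, lora_config)
    model.mark_only_lora_as_trainable()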

paddlenlp/peft/lora/lora_layers.py

Lines changed: 56 additions & 4 deletions

@@ -63,6 +63,8 @@ def __init__(
        rslora: bool = False,
        lora_plus_scale: float = 1.0,
        pissa: bool = False,
+        nola: bool = False,
+        nola_basis_num: int = 1,
        lora_use_mixer: bool = False,
        mixer_num: int = 1,
        use_mora: bool = False,
@@ -85,6 +87,8 @@ def __init__(
        # Mark the weight as unmerged
        self.merged = False
        self.pissa = pissa
+        self.nola = nola
+        self.nola_basis_num = nola_basis_num
        self.lora_use_mixer = lora_use_mixer
        self.mixer_num = mixer_num
        self.lorapro = lorapro
@@ -144,6 +148,32 @@ def __init__(
                ),
            )
            self.apply_pissa = False
+        if nola:
+            # NOLA: frozen random basis matrices plus trainable scalar coefficients
+            self.nola_basis_A = self.create_parameter(
+                shape=[nola_basis_num, in_features, r],
+                dtype=self._dtype,
+                is_bias=False,
+            )
+            self.nola_basis_A.stop_gradient = True
+            self.nola_basis_B = self.create_parameter(
+                shape=[nola_basis_num, r, out_features],
+                dtype=self._dtype,
+                is_bias=False,
+            )
+            self.nola_basis_B.stop_gradient = True
+            self.nola_alpha = self.create_parameter(
+                shape=[nola_basis_num],
+                dtype=self._dtype,
+                is_bias=False,
+                default_initializer=nn.initializer.Constant(value=0.0),
+            )
+            self.nola_beta = self.create_parameter(
+                shape=[nola_basis_num],
+                dtype=self._dtype,
+                is_bias=False,
+                default_initializer=nn.initializer.Constant(value=0.0),
+            )
        if use_mora or pissa:
            self.scaling = 1.0
        elif not rslora:
@@ -179,6 +209,16 @@ def pissa_init(self, rank):
        weight = res.astype(dtype)
        self.weight.set_value(weight)

+    def get_nola_lora_matrices(self):
+        """Compute LoRA matrices A and B from the NOLA basis and coefficients."""
+        if not self.nola:
+            return self.lora_A, self.lora_B
+        # A = sum_k alpha_k * A_k -> [in_features, r]
+        lora_A = paddle.einsum("k,kir->ir", self.nola_alpha, self.nola_basis_A)
+        # B = sum_k beta_k * B_k -> [r, out_features]
+        lora_B = paddle.einsum("k,kro->ro", self.nola_beta, self.nola_basis_B)
+        return lora_A, lora_B
+
    def rope_init(self):
        if self.cos is None or self.sin is None:
            inv_freq = 1.0 / (10000 ** (paddle.arange(0, self.r, 2, dtype=paddle.float32) / self.r))
@@ -257,6 +297,9 @@ def get_delta_weight(self, lora_A=None, lora_B=None, lora_AB=None):
            w = w[: self.out_features]
            final_weight = w
            delta_weight = final_weight.T
+        elif self.nola:
+            lora_A, lora_B = self.get_nola_lora_matrices()
+            delta_weight = lora_A @ lora_B * self.scaling
        else:
            lora_A = lora_A if lora_A is not None else self.lora_A
            lora_B = lora_B if lora_B is not None else self.lora_B
@@ -299,6 +342,11 @@ def forward(self, input: paddle.Tensor, *args, **kwargs):
            input = self.lora_dropout(input)
            mora_out = self._apply_mora(input)
            result += mora_out
+        elif self.nola:
+            result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
+            input = self.lora_dropout(input)
+            lora_A, lora_B = self.get_nola_lora_matrices()
+            result += (input @ lora_A @ lora_B) * self.scaling
        else:
            result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
            if self.lora_use_mixer:
@@ -327,14 +375,16 @@ def __init__(
        use_quick_lora: bool = False,
        pissa: bool = False,
        use_mora: bool = False,
+        nola: bool = False,
+        nola_basis_num: int = 1,
        **kwargs
    ):
        RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
        if not isinstance(r, int) or r <= 0:
            raise ValueError("Lora rank r should be a positive integer")

-        if pissa or use_mora:
-            raise ValueError("Pissa or Mora is not supported in model parallel by now")
+        if pissa or use_mora or nola:
+            raise ValueError("PiSSA, MoRA, and NoLA are not yet supported in model parallel")

        self.r = r
        self.lora_alpha = lora_alpha
@@ -593,14 +643,16 @@ def __init__(
        use_quick_lora: bool = False,
        pissa: bool = False,
        use_mora: bool = False,
+        nola: bool = False,
+        nola_basis_num: int = 1,
        **kwargs
    ):
        ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
        if not isinstance(r, int) or r <= 0:
            raise ValueError("Lora rank r should be a positive integer")

-        if pissa or use_mora:
-            raise ValueError("Pissa or Mora is not supported in model parallel by now")
+        if pissa or use_mora or nola:
+            raise ValueError("PiSSA, MoRA, and NoLA are not yet supported in model parallel")

        self.r = r
        self.lora_alpha = lora_alpha
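To make the reparameterization above concrete: NOLA builds each adapter matrix as a linear combination of frozen random bases, so only the k alpha and k beta scalars are trained per layer. A small self-contained NumPy sketch (shapes follow the layer; the scaling of lora_alpha / r is an assumption mirroring the default LoRA scaling):

    # Standalone NumPy sketch of the NOLA delta-weight computation above.
    import numpy as np

    in_features, out_features, r, k = 64, 32, 8, 3
    rng = np.random.default_rng(0)

    basis_A = rng.standard_normal((k, in_features, r))   # frozen random basis
    basis_B = rng.standard_normal((k, r, out_features))  # frozen random basis
    alpha = np.zeros(k)  # trainable coefficients (zero-initialized, as in the layer)
    beta = np.zeros(k)   # trainable coefficients
    scaling = 16 / r     # assumed lora_alpha / r

    lora_A = np.einsum("k,kir->ir", alpha, basis_A)  # [in_features, r]
    lora_B = np.einsum("k,kro->ro", beta, basis_B)   # [r, out_features]
    delta_weight = (lora_A @ lora_B) * scaling       # [in_features, out_features]
    # With zero-initialized coefficients the delta starts at exactly zero,
    # so the adapted layer initially matches the base weight.

Because the bases are frozen and recoverable from a random seed, only the coefficient vectors need to be stored per layer, which is the compression NOLA targets.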

paddlenlp/peft/lora/lora_model.py

Lines changed: 6 additions & 0 deletions

@@ -482,6 +482,8 @@ def _find_and_replace_module(self, model, module_name, lora_config):
                rslora=lora_config.rslora,
                lora_plus_scale=lora_config.lora_plus_scale,
                pissa=lora_config.pissa,
+                nola=lora_config.nola,
+                nola_basis_num=lora_config.nola_basis_num,
                bias_attr=False if module.bias is None else None,
                use_quick_lora=lora_config.use_quick_lora,
                lora_use_mixer=lora_config.lora_use_mixer,
@@ -521,6 +523,8 @@ def _find_and_replace_module(self, model, module_name, lora_config):
                rslora=lora_config.rslora,
                lora_plus_scale=lora_config.lora_plus_scale,
                pissa=lora_config.pissa,
+                nola=lora_config.nola,
+                nola_basis_num=lora_config.nola_basis_num,
                lora_A_weight_attr=paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
                ),
@@ -547,6 +551,8 @@ def _find_and_replace_module(self, model, module_name, lora_config):
                rslora=lora_config.rslora,
                lora_plus_scale=lora_config.lora_plus_scale,
                pissa=lora_config.pissa,
+                nola=lora_config.nola,
+                nola_basis_num=lora_config.nola_basis_num,
                use_quick_lora=lora_config.use_quick_lora,
            )
        # LoRA column parallel will split the lora_A matrix

paddlenlp/trl/model_config.py

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,8 @@ class ModelConfig:
    lora_use_mixer: bool = field(
        default=False, metadata={"help": "Whether to use MosLoRA: https://arxiv.org/pdf/2406.11909"}
    )
+    nola: bool = field(default=False, metadata={"help": "Whether to use NoLA: https://arxiv.org/pdf/2310.02556"})
+    nola_basis_num: int = field(default=1, metadata={"help": "Number of basis matrices per side when using NoLA"})
    mixer_num: int = field(default=1, metadata={"help": "Num of mixer matrices."})
    use_mora: bool = field(
        default=False, metadata={"help": "Whether to use MoRA: https://arxiv.org/pdf/2405.12130.pdf"}
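These ModelConfig fields surface as command-line flags for llm/run_finetune.py through the argument parser. A hedged sketch of an invocation — the flag names follow the dataclass fields, but the paths, model, and other arguments are placeholders:

    # Hedged sketch: invoking the finetune entry point with the new flags.
    import subprocess

    subprocess.run(
        [
            "python", "llm/run_finetune.py",
            "--model_name_or_path", "__internal_testing__/tiny-random-llama",
            "--dataset_name_or_path", "./data",
            "--output_dir", "./checkpoints/nola",
            "--lora", "true",
            "--nola", "true",
            "--nola_basis_num", "3",
        ],
        check=True,
    )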

tests/fixtures/llm/nola.yaml

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
nola:
  base:
    dataset_name_or_path: "./data"
    per_device_train_batch_size: 4
    gradient_accumulation_steps: 4
    per_device_eval_batch_size: 8
    eval_accumulation_steps: 16
    num_train_epochs: 3
    learning_rate: 3e-04
    warmup_steps: 30
    logging_steps: 1
    evaluation_strategy: "epoch"
    save_strategy: "epoch"
    src_length: 1024
    max_length: 2048
    fp16: true
    fp16_opt_level: "O2"
    do_train: true
    do_eval: true
    disable_tqdm: true
    load_best_model_at_end: true
    eval_with_do_generation: false
    metric_for_best_model: "accuracy"
    recompute: true
    save_total_limit: 1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    lora: true
    nola: true
    nola_basis_num: 3

  default:
    llama:
      model_name_or_path: __internal_testing__/tiny-random-llama
    chatglm:
      model_name_or_path: __internal_testing__/tiny-fused-chatglm
    chatglm2:
      model_name_or_path: __internal_testing__/tiny-fused-chatglm2
    bloom:
      model_name_or_path: __internal_testing__/tiny-fused-bloom
    qwen:
      model_name_or_path: __internal_testing__/tiny-fused-qwen
    baichuan:
      model_name_or_path: __internal_testing__/tiny-fused-baichuan

inference-predict:
  default:
    mode: dynamic
    max_length: 20
    batch_size: 2
    decode_strategy: greedy_search
    dtype: float16

inference-to-static:
  default:
    dtype: float16
    max_length: 20

inference-infer:
  default:
    mode: static
    dtype: float16
    batch_size: 2
    decode_strategy: greedy_search
    max_length: 20
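The test below reads this fixture through load_test_config, selecting the shared "base" block plus the per-model entry under "default". A minimal sketch mirroring that lookup as performed in test_nola.py:

    # Minimal sketch of how the fixture above is consumed by the test.
    from tests.testing_utils import load_test_config

    config = load_test_config("./tests/fixtures/llm/nola.yaml", "nola", "llama")
    print(config["nola"], config["nola_basis_num"])  # expected: True 3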

tests/llm/test_nola.py

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
import sys
import unittest

import paddle
from parameterized import parameterized_class

from tests.testing_utils import argv_context_guard, load_test_config

from .testing_utils import LLMTest


@parameterized_class(
    ["model_dir"],
    [
        ["llama"],
        # ["chatglm"],  # Skip and wait to fix.
        # ["chatglm2"],  # Skip and wait to fix.
        # ["bloom"],  # Skip and wait to fix.
        ["qwen"],
        ["baichuan"],
    ],
)
class NolaTest(LLMTest, unittest.TestCase):
    config_path: str = "./tests/fixtures/llm/nola.yaml"
    model_dir: str = None

    def setUp(self) -> None:
        LLMTest.setUp(self)

        self.model_codes_dir = os.path.join(self.root_path, self.model_dir)
        sys.path.insert(0, self.model_codes_dir)

    def tearDown(self) -> None:
        LLMTest.tearDown(self)
        sys.path.remove(self.model_codes_dir)

    def test_nola(self):
        self.disable_static()
        paddle.set_default_dtype("float32")

        nola_config = load_test_config(self.config_path, "nola", self.model_dir)
        nola_config["output_dir"] = self.output_dir
        nola_config["dataset_name_or_path"] = self.data_dir

        with argv_context_guard(nola_config):
            from run_finetune import main

            main()

        # merge weights
        merge_lora_weights_config = {
            "lora_path": nola_config["output_dir"],
            "model_name_or_path": nola_config["model_name_or_path"],
            "output_path": nola_config["output_dir"],
        }
        with argv_context_guard(merge_lora_weights_config):
            from tools.merge_lora_params import merge

            merge()

        # TODO(wj-Mcat): disable chatglm2 test temporarily
        if self.model_dir not in ["qwen", "baichuan", "chatglm2"]:
            self.run_predictor({"inference_model": True})

        self.run_predictor({"inference_model": False})
