Commit a6d3a28

Pissa (#8250)

* single gpu pissa impl
* fix
* scale update
* raise error in mp
* update pissa config
* add pissa config
1 parent bd25e0c commit a6d3a28

8 files changed: +112 -2 lines changed


llm/argument.py

Lines changed: 1 addition & 0 deletions

@@ -196,6 +196,7 @@ class ModelArgument:
     )
     rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
     lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
+    pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})

     # prefix tuning related parameters
     prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})

llm/finetune_generation.py

Lines changed: 1 addition & 0 deletions

@@ -464,6 +464,7 @@ def neft_post_hook(module, input, output):
                 lora_alpha=2 * model_args.lora_rank if not model_args.rslora else 4,
                 rslora=model_args.rslora,
                 lora_plus_scale=model_args.lora_plus_scale,
+                pissa=model_args.pissa,
                 merge_weights=False,
                 tensor_parallel_degree=training_args.tensor_parallel_degree,
                 dtype=dtype,

llm/llama/lora_argument.json

Lines changed: 1 addition & 1 deletion

@@ -29,4 +29,4 @@
     "lora": true,
     "zero_padding": false,
     "use_flash_attention": false
-}
+}

llm/llama/lora_argument_pissa.json

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+{
+    "model_name_or_path": "facebook/llama-7b",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/llama_lora_ckpts",
+    "per_device_train_batch_size": 4,
+    "gradient_accumulation_steps": 32,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps":16,
+    "num_train_epochs": 3,
+    "learning_rate": 2e-05,
+    "warmup_steps": 10,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 1024,
+    "max_length": 2048,
+    "fp16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
+    "lora": true,
+    "pissa": false,
+    "zero_padding": false,
+    "use_flash_attention": false
+}

llm/qwen/lora_argument_pissa.json

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+{
+    "model_name_or_path": "qwen/qwen-7b",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/qwen_lora_ckpts",
+    "per_device_train_batch_size": 4,
+    "gradient_accumulation_steps": 32,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps":16,
+    "num_train_epochs": 3,
+    "learning_rate": 2e-05,
+    "warmup_steps": 10,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 1024,
+    "max_length": 2048,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
+    "lora": true,
+    "pissa": true,
+    "zero_padding": false,
+    "use_flash_attention": false
+}

paddlenlp/peft/lora/lora_config.py

Lines changed: 1 addition & 0 deletions

@@ -74,6 +74,7 @@ class LoRAConfig:
     )
     do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"})
     rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
+    pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})
     lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"})
     base_model_name_or_path: Optional[str] = field(
         default=None, metadata={"help": "The name of the base model to use."}
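
For context, a minimal sketch of how the new flag could be consumed from user code, assuming the public paddlenlp.peft entry points (LoRAConfig / LoRAModel); the rank, dtype, and target_modules pattern below are illustrative assumptions, and the actual wiring used by this repository is the llm/finetune_generation.py diff above.

# Sketch only, not code from this commit; model choice, rank, and the
# target_modules regexes are hypothetical examples.
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b", dtype="float16")
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # hypothetical module pattern
    r=8,
    lora_alpha=16,
    pissa=True,          # field added by this commit; LoRALinear then uses scaling = 1.0
    merge_weights=False,
    dtype="float16",
)
model = LoRAModel(model, lora_config)
model.mark_only_lora_as_trainable()  # freeze base weights, as in a standard LoRA run

With pissa=True, each LoRALinear runs pissa_init on its first forward pass, and combining the flag with tensor parallelism raises a ValueError (see the lora_layers.py diff below).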

paddlenlp/peft/lora/lora_layers.py

Lines changed: 39 additions & 1 deletion

@@ -47,6 +47,7 @@ def __init__(
         use_quick_lora: bool = False,
         rslora: bool = False,
         lora_plus_scale: float = 1.0,
+        pissa: bool = False,
         **kwargs
     ):
         nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -62,6 +63,7 @@ def __init__(
         # Mark the weight as unmerged
         self.merged = False
         self.merge_weights = merge_weights
+        self.pissa = pissa

         # Actual trainable parameters
         self.lora_A = self.create_parameter(
@@ -79,9 +81,12 @@ def __init__(
                 learning_rate=lora_plus_scale,
             ),
         )
+        self.apply_pissa = False

-        if not rslora:
+        if not rslora and not pissa:
             self.scaling = self.lora_alpha / self.r
+        elif pissa:
+            self.scaling = 1.0
         else:
             self.scaling = self.lora_alpha / math.sqrt(self.r)

@@ -93,6 +98,25 @@ def __init__(
     def use_quick_lora(self):
         return self._use_quick_lora and self.training and not self.merged

+    def pissa_init(self, rank):
+        weight = self.weight
+        dtype = weight.dtype
+        if dtype != paddle.float32:
+            weight = weight.astype(paddle.float32)
+
+        U, S, Vh = paddle.linalg.svd(weight.data, full_matrices=False)
+        Ur = U[:, :rank]
+        Sr = S[:rank]
+        Vhr = Vh[:rank]
+
+        lora_A = Ur @ paddle.diag(paddle.sqrt(Sr))
+        lora_B = paddle.diag(paddle.sqrt(Sr)) @ Vhr
+        self.lora_A.set_value(lora_A.astype(dtype))
+        self.lora_B.set_value(lora_B.astype(dtype))
+        res = weight.data - lora_A @ lora_B
+        weight = res.astype(dtype)
+        self.weight.set_value(weight)
+
     def train(self):
         super().train()
         if self.merge_weights and self.merged:
@@ -110,6 +134,10 @@ def eval(self):
             self.merged = True

     def forward(self, input: paddle.Tensor, *args, **kwargs):
+        if not self.apply_pissa and self.pissa:
+            self.pissa_init(self.r)
+            self.apply_pissa = True
+
         if self.use_quick_lora:
             # Use the quick lora implementation
             result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling)
@@ -136,11 +164,16 @@ def __init__(
         lora_plus_scale: float = 1.0,
         merge_weights: bool = True,
         use_quick_lora: bool = False,
+        pissa: bool = False,
         **kwargs
     ):
         RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
         if not isinstance(r, int) or r <= 0:
             raise ValueError("Lora rank r should be a positive integer")
+
+        if pissa:
+            raise ValueError("Pissa is not supported in model parallel by now")
+
         self.r = r
         self.lora_alpha = lora_alpha
         # Optional dropout
@@ -278,11 +311,16 @@ def __init__(
         merge_weights: bool = True,
         lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
         use_quick_lora: bool = False,
+        pissa: bool = False,
         **kwargs
     ):
         ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
         if not isinstance(r, int) or r <= 0:
             raise ValueError("Lora rank r should be a positive integer")
+
+        if pissa:
+            raise ValueError("Pissa is not supported in model parallel by now")
+
         self.r = r
         self.lora_alpha = lora_alpha
         # Optional dropout
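
The pissa_init method added above follows the PiSSA recipe: the top-r singular components of the frozen weight seed lora_A and lora_B, and the residual is written back into the base weight, so with scaling = 1.0 the layer's initial output is unchanged. A small self-contained numpy sketch of that decomposition (shapes are made up for illustration):

# Numpy sketch of the SVD split performed in pissa_init; not code from this commit.
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, rank = 64, 32, 8
W = rng.standard_normal((in_features, out_features)).astype(np.float32)

U, S, Vh = np.linalg.svd(W, full_matrices=False)
Ur, Sr, Vhr = U[:, :rank], S[:rank], Vh[:rank]

lora_A = Ur @ np.diag(np.sqrt(Sr))        # principal directions, shape [in_features, rank]
lora_B = np.diag(np.sqrt(Sr)) @ Vhr       # shape [rank, out_features]
residual = W - lora_A @ lora_B            # replaces the frozen base weight

# The split is exact: base (residual) plus adapter (lora_A @ lora_B) recovers W,
# which is why PiSSA fixes scaling to 1.0 instead of lora_alpha / r.
assert np.allclose(residual + lora_A @ lora_B, W, atol=1e-4)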

paddlenlp/peft/lora/lora_model.py

Lines changed: 3 additions & 0 deletions

@@ -384,6 +384,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 merge_weights=lora_config.merge_weights,
                 rslora=lora_config.rslora,
                 lora_plus_scale=lora_config.lora_plus_scale,
+                pissa=lora_config.pissa,
                 bias_attr=False if module.bias is None else None,
                 use_quick_lora=lora_config.use_quick_lora,
             )
@@ -417,6 +418,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 lora_dropout=lora_config.lora_dropout,
                 rslora=lora_config.rslora,
                 lora_plus_scale=lora_config.lora_plus_scale,
+                pissa=lora_config.pissa,
                 merge_weights=lora_config.merge_weights,
                 lora_A_weight_attr=paddle.ParamAttr(
                     initializer=nn.initializer.KaimingUniform(
@@ -445,6 +447,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 lora_dropout=lora_config.lora_dropout,
                 rslora=lora_config.rslora,
                 lora_plus_scale=lora_config.lora_plus_scale,
+                pissa=lora_config.pissa,
                 merge_weights=lora_config.merge_weights,
                 use_quick_lora=lora_config.use_quick_lora,
             )

0 commit comments
