
Commit e298ba4

Liebelebruicecode and moge authored

LoriKiT (#9776)

* add linchain
* resolve conflict
* add linchain test
* resolve pre-commit
* resolve ci problem and add description

Co-authored-by: moge <[email protected]>
1 parent 19a585b commit e298ba4

File tree

10 files changed: +567 -19 lines

llm/docs/finetune.md (1 addition, 1 deletion)

@@ -94,7 +94,7 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.
 3. The backbone model can be quantized to low bit width by setting `weight_quantize_algo`, e.g. 'weight_only_int4', 'weight_only_int8', 'nf4' or 'fp4'. See the fine-tuning parameter reference for details.
 4. Set `use_flash_attention` to True to use FlashAttention. With FlashAttention enabled, set `flash_mask` to True to use FlashMask.
 5. The LoRA API supports 4D parallelism; the parallel training strategy can be adjusted via `tensor_parallel_degree`, `pipeline_parallel_degree`, `sharding` and `sharding_parallel_degree`, scaling up to **LoRA fine-tuning of hundred-billion-parameter models on a single machine**.
-6. Parameters such as `rslora`, `lora_plus_scale`, `pissa`, `lora_use_mixer` and `use_mora` can be configured to use the rsLoRA, LoRA+, PiSSA, MosLoRA (tensor model parallelism not yet supported) and MoRA (tensor model parallelism not yet supported) algorithms.
+6. Parameters such as `rslora`, `lora_plus_scale`, `pissa`, `lora_use_mixer`, `mixer_num` and `use_mora` can be configured to use the rsLoRA, LoRA+, PiSSA, MosLoRA (tensor model parallelism not yet supported), LinChain (tensor model parallelism not yet supported) and MoRA (tensor model parallelism not yet supported) algorithms.
 
 For convenient downstream **compression** and **static-graph inference**, we provide a LoRA parameter merge script that merges the LoRA parameters into the backbone model and saves the corresponding weights.
 ```
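
For reference, the LinChain update enabled by `mixer_num` replaces MosLoRA's single r x r mixer with a chain of `mixer_num` trainable r x r matrices between `lora_A` and `lora_B`. A minimal numpy sketch of the resulting delta weight (hypothetical shapes for illustration; not the PaddleNLP API):

```python
import numpy as np

# Hypothetical sizes, chosen only for this sketch.
in_features, out_features, r, mixer_num = 64, 32, 8, 3
scaling = 2.0  # stands in for lora_alpha / r

lora_A = np.random.randn(in_features, r)
mixers = [np.random.randn(r, r) for _ in range(mixer_num)]  # the LinChain mixers
lora_B = np.random.randn(r, out_features)

# Chain the mixers left to right: M_0 @ M_1 @ ... @ M_{k-1}.
chain = mixers[0]
for m in mixers[1:]:
    chain = chain @ m

# delta_weight has the same shape as the frozen base weight.
delta_weight = lora_A @ chain @ lora_B * scaling
print(delta_weight.shape)  # (64, 32)
```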

llm/run_finetune.py (1 addition)

@@ -580,6 +580,7 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
             use_quick_lora=model_args.use_quick_lora,
             lora_use_mixer=model_args.lora_use_mixer,
             use_mora=model_args.use_mora,
+            mixer_num=model_args.mixer_num,
             lorapro=model_args.lorapro,
         )
         if model_args.lorapro:
llm/tools/merge_lora_params.py (23 additions, 6 deletions)

@@ -78,51 +78,68 @@ def weight_process(name, quant_config, lora_config, state_dict, device):
         raise ValueError(f"quant_config.weight_quantize_algo {quant_config.weight_quantize_algo} is not supported.")
 
 
+def get_mixer(mixer, mixer_num, index=0):
+    if index == mixer_num - 1:
+        return mixer[index]
+    else:
+        return mixer[index] @ get_mixer(mixer, mixer_num, index + 1)
+
+
 def lora_process(name, layer, lora_config, state_dict, device, lora_state_dict=None):
     target_device = device if device == "cpu" else device + ":0"
 
     if (name + ".weight") not in state_dict.keys():
         return
 
     weight = state_dict.pop(name + ".weight")
     lora_use_mixer = lora_config.lora_use_mixer
+    mixer_num = lora_config.mixer_num
+    mixer = {}
     use_mora = lora_config.use_mora
     if lora_state_dict is None:
         lora_A = state_dict.pop(name + ".lora_A")
         if not use_mora:
             lora_B = state_dict.pop(name + ".lora_B")
         if lora_use_mixer:
-            lora_AB = state_dict.pop(name + ".lora_AB")
+            for i in range(mixer_num):
+                mixer[i] = state_dict.pop(name + ".lora_mixer_" + str(i))
     else:
         lora_A = lora_state_dict.pop(name + ".lora_A")
         if not use_mora:
             lora_B = lora_state_dict.pop(name + ".lora_B")
         if lora_use_mixer:
-            lora_AB = lora_state_dict.pop(name + ".lora_AB")
+            for i in range(mixer_num):
+                mixer[i] = state_dict.pop(name + ".lora_mixer_" + str(i))
     if device != "cpu":
         weight = weight.to(target_device)
         lora_A = lora_A.to(target_device)
         if not use_mora:
             lora_B = lora_B.to(target_device)
         if lora_use_mixer:
-            lora_AB = lora_AB.to(target_device)
+            for key in mixer.keys():
+                mixer[key] = mixer[key].to(target_device)
 
     if device == "cpu" and weight.dtype.name == "BF16":
         weight = weight.astype("float32")
         lora_A = lora_A.astype("float32")
         if not use_mora:
             lora_B = lora_B.astype("float32")
+
         if lora_use_mixer:
-            lora_AB = lora_AB.astype(lora_config.dtype)
-            delta_weight = layer.get_delta_weight(lora_A, lora_B, lora_AB)
+            for key in mixer.keys():
+                mixer[key] = mixer[key].astype(lora_config.dtype)
+            delta_weight = layer.get_delta_weight(lora_A, lora_B, get_mixer(mixer, mixer_num))
         elif use_mora:
             delta_weight = layer.get_delta_weight(lora_A)
         else:
             delta_weight = layer.get_delta_weight(lora_A, lora_B)
         out = (weight + delta_weight).astype(lora_config.dtype)
     else:
         if lora_use_mixer:
-            delta_weight = layer.get_delta_weight(lora_A, lora_B, lora_AB)
+            delta_weight = layer.get_delta_weight(lora_A, lora_B, get_mixer(mixer, mixer_num))
         elif use_mora:
             delta_weight = layer.get_delta_weight(lora_A)
         else:
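
The recursive `get_mixer` helper above collapses the dict `{0: M_0, ..., k-1: M_{k-1}}` into the ordered product `M_0 @ M_1 @ ... @ M_{k-1}` before it is handed to `layer.get_delta_weight`. A standalone numpy check of that identity (a sketch, not the merge script itself):

```python
import numpy as np

def get_mixer(mixer, mixer_num, index=0):
    # Same recursion as in merge_lora_params.py: multiply from the left.
    if index == mixer_num - 1:
        return mixer[index]
    return mixer[index] @ get_mixer(mixer, mixer_num, index + 1)

r, k = 4, 3
mixer = {i: np.random.randn(r, r) for i in range(k)}

# The recursion equals the explicit left-to-right product.
assert np.allclose(get_mixer(mixer, k), mixer[0] @ mixer[1] @ mixer[2])
```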

paddlenlp/peft/lora/lora_config.py (6 additions)

@@ -94,6 +94,12 @@ class LoRAConfig:
         default=False,
         metadata={"help": "Whether to use mos lora."},
     )
+    mixer_num: int = field(
+        default=1,
+        metadata={
+            "help": "Num of mixer matrices. Mixer matrices will be added between the LoRA_A and LoRA_B matrices, as referenced in the paper https://arxiv.org/abs/2411.00039."
+        },
+    )
     lorapro: bool = field(default=False, metadata={"help": "Whether to use LoRA-PRO"})
 
     def __post_init__(self):
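
A hedged usage sketch of the new field: enabling LinChain through `LoRAConfig` (the `target_modules` patterns, rank, and alpha below are illustrative choices, not values from this commit):

```python
from paddlenlp.peft import LoRAConfig

lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # illustrative patterns
    r=8,
    lora_alpha=16,
    lora_use_mixer=True,  # insert mixer matrices between lora_A and lora_B
    mixer_num=3,          # chain three r x r mixers (LinChain)
)
```

Since `mixer_num` defaults to 1, existing `lora_use_mixer` (MosLoRA) configs keep their single-mixer behavior unchanged.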

paddlenlp/peft/lora/lora_layers.py (29 additions, 12 deletions)

@@ -64,6 +64,7 @@ def __init__(
         lora_plus_scale: float = 1.0,
         pissa: bool = False,
         lora_use_mixer: bool = False,
+        mixer_num: int = 1,
         use_mora: bool = False,
         lorapro: bool = False,
         mp_moe: bool = False,
@@ -85,6 +86,7 @@ def __init__(
         self.merged = False
         self.pissa = pissa
         self.lora_use_mixer = lora_use_mixer
+        self.mixer_num = mixer_num
         self.lorapro = lorapro
 
         # Actual trainable parameters
@@ -118,14 +120,20 @@ def __init__(
                 ),
             )
         if self.lora_use_mixer:
-            self.lora_AB = self.create_parameter(
-                shape=[r, r],
-                dtype=self._dtype,
-                is_bias=False,
-                default_initializer=nn.initializer.KaimingUniform(
-                    negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
-                ),
-            )
+            for i in range(self.mixer_num):
+                key = "lora_mixer_" + str(i)
+                setattr(
+                    self,
+                    key,
+                    self.create_parameter(
+                        shape=[r, r],
+                        dtype=self._dtype,
+                        is_bias=False,
+                        default_initializer=nn.initializer.KaimingUniform(
+                            negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
+                        ),
+                    ),
+                )
         self.lora_B = self.create_parameter(
             shape=[r, out_features],
             dtype=self._dtype,
@@ -221,7 +229,7 @@ def get_delta_weight(self, lora_A=None, lora_B=None, lora_AB=None):
         if self.lora_use_mixer:
             lora_A = lora_A if lora_A is not None else self.lora_A
             lora_B = lora_B if lora_B is not None else self.lora_B
-            lora_AB = lora_AB if lora_AB is not None else self.lora_AB
+            lora_AB = lora_AB if lora_AB is not None else self.get_mixer_params(0)
             delta_weight = lora_A @ lora_AB @ lora_B * self.scaling
         elif self.use_mora:
             lora_A = lora_A if lora_A is not None else self.lora_A
@@ -256,18 +264,25 @@ def get_delta_weight(self, lora_A=None, lora_B=None, lora_AB=None):
 
         return delta_weight
 
+    def get_mixer_params(self, index):
+        key = "lora_mixer_" + str(index)
+        if index == self.mixer_num - 1:
+            return getattr(self, key)
+        else:
+            return getattr(self, key) @ self.get_mixer_params(index + 1)
+
     def merge(self):
         if not self.merged:
             delta_weight = self.get_delta_weight()
             new_weight = self.weight + delta_weight
-            self.weight.set_value(new_weight)
+            self.weight.set_value(new_weight.astype(self.weight.dtype))
             self.merged = True
 
     def unmerge(self):
         if self.merged:
             delta_weight = self.get_delta_weight()
             new_weight = self.weight - delta_weight
-            self.weight.set_value(new_weight)
+            self.weight.set_value(new_weight.astype(self.weight.dtype))
             self.merged = False
 
     def forward(self, input: paddle.Tensor, *args, **kwargs):
@@ -287,7 +302,9 @@ def forward(self, input: paddle.Tensor, *args, **kwargs):
         else:
             result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
             if self.lora_use_mixer:
-                result += (self.lora_dropout(input) @ self.lora_A @ self.lora_AB @ self.lora_B) * self.scaling
+                result += (
+                    self.lora_dropout(input) @ self.lora_A @ self.get_mixer_params(0) @ self.lora_B
+                ) * self.scaling
             else:
                 result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
         return result
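
`get_mixer_params(0)` returns the chained product of the per-layer mixers, so merging the adapter into the base weight must agree with the on-the-fly branch in `forward`. A small numpy consistency check of that invariant (dropout and bias omitted; a sketch, not the Paddle layer):

```python
import numpy as np

rng = np.random.default_rng(0)
in_f, out_f, r, scaling = 16, 12, 4, 2.0

W = rng.standard_normal((in_f, out_f))               # frozen base weight
A = rng.standard_normal((in_f, r))                   # lora_A
B = rng.standard_normal((r, out_f))                  # lora_B
M = [rng.standard_normal((r, r)) for _ in range(3)]  # lora_mixer_0..2

chain = M[0] @ M[1] @ M[2]  # what get_mixer_params(0) computes
x = rng.standard_normal((2, in_f))

unmerged = x @ W + (x @ A @ chain @ B) * scaling  # forward() branch
merged = x @ (W + A @ chain @ B * scaling)        # after merge()
assert np.allclose(unmerged, merged)
```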

paddlenlp/peft/lora/lora_model.py (1 addition)

@@ -485,6 +485,7 @@ def _find_and_replace_module(self, model, module_name, lora_config):
             bias_attr=False if module.bias is None else None,
             use_quick_lora=lora_config.use_quick_lora,
             lora_use_mixer=lora_config.lora_use_mixer,
+            mixer_num=lora_config.mixer_num,
             use_mora=lora_config.use_mora,
             mp_moe=getattr(module.weight, "mp_moe", False),
             is_distributed=getattr(module.weight, "is_distributed", False),

paddlenlp/trl/model_config.py (1 addition)

@@ -64,6 +64,7 @@ class ModelConfig:
     lora_use_mixer: bool = field(
         default=False, metadata={"help": "Whether to use MosLoRA: https://arxiv.org/pdf/2406.11909"}
     )
+    mixer_num: int = field(default=1, metadata={"help": "Num of mixer matrices."})
     use_mora: bool = field(
         default=False, metadata={"help": "Whether to use MoRA: https://arxiv.org/pdf/2405.12130.pdf"}
     )

tests/fixtures/llm/linchain.yaml (new file, 114 additions)

@@ -0,0 +1,114 @@
+lora:
+  base:
+    dataset_name_or_path: "./data"
+    per_device_train_batch_size: 4
+    gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 8
+    eval_accumulation_steps: 16
+    num_train_epochs: 3
+    learning_rate: 3e-04
+    warmup_steps: 30
+    logging_steps: 1
+    evaluation_strategy: "epoch"
+    save_strategy: "epoch"
+    src_length: 1024
+    max_length: 2048
+    fp16: true
+    fp16_opt_level: "O2"
+    do_train: true
+    do_eval: true
+    disable_tqdm: true
+    load_best_model_at_end: true
+    eval_with_do_generation: false
+    metric_for_best_model: "accuracy"
+    recompute: true
+    save_total_limit: 1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    lora: true
+    lora_use_mixer: true
+    mixer_num: 3
+
+  default:
+    llama:
+      model_name_or_path: __internal_testing__/tiny-random-llama
+    chatglm:
+      model_name_or_path: __internal_testing__/tiny-fused-chatglm
+    chatglm2:
+      model_name_or_path: __internal_testing__/tiny-fused-chatglm2
+    bloom:
+      model_name_or_path: __internal_testing__/tiny-fused-bloom
+    qwen:
+      model_name_or_path: __internal_testing__/tiny-fused-qwen
+    qwen2:
+      model_name_or_path: __internal_testing__/tiny-random-qwen2
+    qwen2moe:
+      model_name_or_path: __internal_testing__/tiny-random-qwen2moe
+    baichuan:
+      model_name_or_path: __internal_testing__/tiny-fused-baichuan
+
+rslora_plus:
+  base:
+    dataset_name_or_path: "./data"
+    per_device_train_batch_size: 4
+    gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 8
+    eval_accumulation_steps: 16
+    num_train_epochs: 3
+    learning_rate: 3e-04
+    warmup_steps: 30
+    logging_steps: 1
+    evaluation_strategy: "epoch"
+    save_strategy: "epoch"
+    src_length: 1024
+    max_length: 2048
+    fp16: true
+    fp16_opt_level: "O2"
+    do_train: true
+    do_eval: true
+    disable_tqdm: true
+    load_best_model_at_end: true
+    eval_with_do_generation: false
+    metric_for_best_model: "accuracy"
+    recompute: true
+    save_total_limit: 1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    lora: true
+    lora_plus_scale: 4
+    rslora: true
+
+  default:
+    llama:
+      model_name_or_path: __internal_testing__/tiny-random-llama
+    chatglm:
+      model_name_or_path: __internal_testing__/tiny-fused-chatglm
+    chatglm2:
+      model_name_or_path: __internal_testing__/tiny-fused-chatglm2
+    bloom:
+      model_name_or_path: __internal_testing__/tiny-fused-bloom
+    qwen:
+      model_name_or_path: __internal_testing__/tiny-fused-qwen
+    baichuan:
+      model_name_or_path: __internal_testing__/tiny-fused-baichuan
+
+inference-predict:
+  default:
+    mode: dynamic
+    max_length: 20
+    batch_size: 2
+    decode_strategy: greedy_search
+    dtype: float16
+
+inference-to-static:
+  default:
+    dtype: float16
+    max_length: 20
+
+inference-infer:
+  default:
+    mode: static
+    dtype: float16
+    batch_size: 2
+    decode_strategy: greedy_search
+    max_length: 20
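
For scale, the fixture's `mixer_num: 3` adds three r x r mixers per adapted linear layer on top of `lora_A` and `lora_B`. A back-of-envelope count, assuming the default rank r = 8 (the YAML does not set `r`) and a hypothetical 1024 x 1024 layer:

```python
def lora_param_count(in_f, out_f, r=8, mixer_num=0):
    # lora_A (in_f x r) + lora_B (r x out_f) + mixer_num mixers (r x r each)
    return in_f * r + r * out_f + mixer_num * r * r

plain = lora_param_count(1024, 1024)                  # 16384
linchain = lora_param_count(1024, 1024, mixer_num=3)  # 16576
print(linchain - plain)  # 192 extra trainable parameters per layer
```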
