
Commit 7314063

fix import
1 parent dc5a6af commit 7314063

File tree

2 files changed: +344 -1 lines changed


paddlenlp/peft/lora/lora_layers.py

Lines changed: 271 additions & 1 deletion
@@ -27,11 +27,44 @@

 from .lora_quick_layers import quick_lora

-if "npu" in paddle.device.get_all_custom_device_type():
+
+def is_mc2_valid():
+    return "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0"))
+
+
+if is_mc2_valid():
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        MC2ColumnSeqParallelLinear,
+        MC2RowSeqParallelLinear,
+    )
+
     from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
 else:
     MC2LoRaRowParallelLinear = None
     MC2LoRaColumnParallelLinear = None
+    MC2ColumnSeqParallelLinear = None
+    MC2RowSeqParallelLinear = None
+
+
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        AllGatherOp,
+        ColumnSequenceParallelLinear,
+        ReduceScatterOp,
+        RowSequenceParallelLinear,
+        mark_as_sequence_parallel_parameter,
+    )
+except:
+
+    class ColumnSequenceParallelLinear:
+        pass
+
+    class RowSequenceParallelLinear:
+        pass
+
+    AllGatherOp = None
+    ReduceScatterOp = None
+    mark_as_sequence_parallel_parameter = None


 class LoRALinear(nn.Linear):
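
For reference, the new is_mc2_valid() gate is truthy only when an "npu" custom device type is registered and the MC2 environment variable is non-zero. Below is a hypothetical sketch (not part of this commit) of toggling that path from user code; it assumes paddlenlp is installed at this revision and an NPU build of Paddle for the truthy case.

import os

os.environ["MC2"] = "1"  # is_mc2_valid() reads int(os.getenv("MC2", "0"))

from paddlenlp.peft.lora.lora_layers import is_mc2_valid  # noqa: E402

# Truthy only on an NPU build with MC2 set; on CPU/GPU builds this stays falsy
# and the ReduceScatterOp/AllGatherOp code path is taken instead of the MC2 kernels.
print(bool(is_mc2_valid()))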
@@ -298,6 +331,123 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


+class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
+        merge_weights: bool = True,
+        use_quick_lora: bool = False,
+        pissa: bool = False,
+        **kwargs
+    ):
+        RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+        if not isinstance(r, int) or r <= 0:
+            raise ValueError("Lora rank r should be a positive integer")
+        if pissa:
+            raise ValueError("Pissa is not supported in model parallel by now")
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+        # compatible
+        self.name = self._name
+
+        # Actual trainable parameters
+        self.lora_A = self.create_parameter(
+            shape=[self.input_size_per_partition, r],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
+            ),
+        )
+        self.lora_B = self.create_parameter(
+            shape=[r, self.out_features],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
+        )
+
+        self.lora_A.is_distributed = True
+        self.lora_A.split_axis = 0
+        self.lora_B.is_distributed = False
+        mark_as_sequence_parallel_parameter(self.lora_B)
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+        # Freezing the pre-trained weight matrix
+        self.weight.stop_gradient = True
+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        # TODO(@gexiao): support qlora
+        return False  # self._use_quick_lora and self.training and not self.merged
+
+    def train(self):
+        super().train()
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
+    def eval(self):
+        super().eval()
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = True
+
+    def forward(self, x: paddle.Tensor):
+        if not self.input_is_parallel:
+            input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
+        else:
+            input_mp = x
+
+        if not is_mc2_valid():
+            output_parallel = self.linear(input_mp, self.weight, name=self._name)
+            output_ = ReduceScatterOp.apply(output_parallel)
+            result_mp = output_ + self.bias if self.bias is not None else output_
+        else:
+            output_ = MC2RowSeqParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
+            result_mp = output_ + self.bias if self.bias is not None else output_
+
+        if not self.merged:
+            input_mp = self.lora_dropout(input_mp)
+            if not is_mc2_valid():
+                input_mp = input_mp @ self.lora_A
+                input_mp = ReduceScatterOp.apply(input_mp)
+            else:
+                input_mp = MC2RowSeqParallelLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
+            delta_mp = (input_mp @ self.lora_B) * self.scaling
+            result_mp += delta_mp
+        return result_mp
+
+    def extra_repr(self):
+        name = f", name={self.name}" if self.name else ""
+        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
 class ColumnParallelLoRALinear(ColumnParallelLinear):
     def __init__(
         self,
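
As a sanity check on the merge/unmerge arithmetic used by train() and eval() above, the merged weight W + A @ B * scaling must reproduce the unmerged forward pass. A minimal single-process sketch (plain Paddle, no tensor or sequence parallelism; shapes and values are chosen purely for illustration):

import paddle

in_features, out_features, r, lora_alpha = 8, 16, 4, 8
scaling = lora_alpha / r  # the rslora=False branch above

weight = paddle.randn([in_features, out_features])
lora_A = paddle.randn([in_features, r])
lora_B = 0.01 * paddle.randn([r, out_features])  # in the layer this starts at zero and is learned

x = paddle.randn([2, in_features])
unmerged = x @ weight + (x @ lora_A) @ lora_B * scaling
merged_weight = weight + lora_A @ lora_B * scaling  # what eval() writes back via set_value
print(paddle.allclose(x @ merged_weight, unmerged))  # True (up to float32 rounding)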
@@ -428,6 +578,126 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


+class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
+        merge_weights: bool = True,
+        lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
+        use_quick_lora: bool = False,
+        pissa: bool = False,
+        **kwargs
+    ):
+        ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+        if not isinstance(r, int) or r <= 0:
+            raise ValueError("Lora rank r should be a positive integer")
+        if pissa:
+            raise ValueError("Pissa is not supported in model parallel by now")
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+        # compatible
+        self.name = self._name
+
+        # Actual trainable parameters
+        self.lora_A = self.create_parameter(
+            shape=[in_features, r],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=lora_A_weight_attr,
+        )
+        self.lora_A.is_distributed = False
+        mark_as_sequence_parallel_parameter(self.lora_A)
+
+        self.lora_B = self.create_parameter(
+            shape=[r, self.output_size_per_partition],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
+        )
+
+        self.lora_B.is_distributed = True
+        self.lora_B.split_axis = 1
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+        # Freezing the pre-trained weight matrix
+        self.weight.stop_gradient = True
+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        # TODO(@gexiao): support qlora
+        return False  # self._use_quick_lora and self.training and not self.merged
+
+    def train(self):
+        super().train()
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
+    def eval(self):
+        super().eval()
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = True
+
+    def forward(self, x: paddle.Tensor):
+        if not is_mc2_valid():
+            if self.is_mp:
+                input_parallel = AllGatherOp.apply(x)
+            else:
+                input_parallel = x
+            result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
+        else:
+            result_mp = MC2ColumnSeqParallelLinear.apply(x, self.weight, self.model_parallel_group)
+            if self.bias is not None:
+                result_mp += self.bias
+
+        if not self.merged:
+            input_a = self.lora_dropout(x) @ self.lora_A
+            if not is_mc2_valid():
+                input_a = AllGatherOp.apply(input_a)
+                delta_mp = (input_a @ self.lora_B) * self.scaling
+            else:
+                input_a = MC2ColumnSeqParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = input_a * self.scaling
+            result_mp += delta_mp
+
+        if self.gather_output and self.is_mp:
+            result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
+        else:
+            result = result_mp
+        return result
+
+    def extra_repr(self):
+        name = f", name={self.name}" if self.name else ""
+        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
 class LoRAMergedLinear(nn.Linear):
     # LoRA implemented in a dense layer with merged linear weights for q, k, v
     def __init__(
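
The column-parallel LoRA path above shards both the base weight and lora_B along the output dimension while keeping lora_A replicated (and marked sequence-parallel). A small single-process sketch, simulating a 2-way column split with plain slicing instead of a real distributed setup, showing that concatenating the per-shard outputs recovers the full result; names and shapes here are illustrative only:

import paddle

in_features, out_features, r, scaling = 8, 16, 4, 2.0
W = paddle.randn([in_features, out_features])
A = paddle.randn([in_features, r])   # replicated on every rank (lora_A)
B = paddle.randn([r, out_features])  # split along axis 1 across ranks (lora_B.split_axis = 1)
x = paddle.randn([2, in_features])

full = x @ W + (x @ A) @ B * scaling
half = out_features // 2
shard0 = x @ W[:, :half] + (x @ A) @ B[:, :half] * scaling  # rank 0's result_mp
shard1 = x @ W[:, half:] + (x @ A) @ B[:, half:] * scaling  # rank 1's result_mp
# Concatenating the shards corresponds to the gather_output concat in forward().
print(paddle.allclose(paddle.concat([shard0, shard1], axis=-1), full))  # True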

0 commit comments