27 | 27 |
28 | 28 | from .lora_quick_layers import quick_lora
29 | 29 |
30 |    | -
31 |    | -def is_mc2_valid():
32 |    | -    return "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0"))
33 |    | -
34 |    | -
35 |    | -if is_mc2_valid():
36 |    | -    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
37 |    | -        MC2ColumnSeqParallelLinear,
38 |    | -        MC2RowSeqParallelLinear,
39 |    | -    )
40 |    | -
   | 30 | +if "npu" in paddle.device.get_all_custom_device_type():
41 | 31 |     from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
42 | 32 | else:
43 | 33 |     MC2LoRaRowParallelLinear = None
44 | 34 |     MC2LoRaColumnParallelLinear = None
45 |    | -    MC2ColumnSeqParallelLinear = None
46 |    | -    MC2RowSeqParallelLinear = None
47 |    | -
48 |    | -
49 |    | -try:
50 |    | -    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
51 |    | -        AllGatherOp,
52 |    | -        ColumnSequenceParallelLinear,
53 |    | -        ReduceScatterOp,
54 |    | -        RowSequenceParallelLinear,
55 |    | -        mark_as_sequence_parallel_parameter,
56 |    | -    )
57 |    | -except:
58 |    | -
59 |    | -    class ColumnSequenceParallelLinear:
60 |    | -        pass
61 |    | -
62 |    | -    class RowSequenceParallelLinear:
63 |    | -        pass
64 |    | -
65 |    | -    AllGatherOp = None
66 |    | -    ReduceScatterOp = None
67 |    | -    mark_as_sequence_parallel_parameter = None
68 | 35 |
69 | 36 |
70 | 37 | class LoRALinear(nn.Linear):
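Note on this hunk: the removed `is_mc2_valid()` gate required both an NPU custom device and the `MC2` environment variable, while the added guard keys the `mc2_lora_npu` imports off the device check alone; the sequence-parallel MC2 ops and the `ColumnSequenceParallelLinear`/`RowSequenceParallelLinear` fallbacks are no longer imported in this module. A minimal sketch of how the two guards differ (illustrative only, not part of the diff):

```python
# Illustrative comparison of the old and new import guards.
import os

import paddle

has_npu = "npu" in paddle.device.get_all_custom_device_type()

# Old gate: NPU custom device AND the MC2 env var set to a non-zero value.
old_gate = has_npu and bool(int(os.getenv("MC2", "0")))

# New gate: the device check alone decides whether the MC2 LoRA ops are imported.
new_gate = has_npu

print(old_gate, new_gate)
```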

@@ -331,123 +298,6 @@ def extra_repr(self):
331 | 298 |         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
332 | 299 |
333 | 300 |
334 |     | -class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
335 |     | -    def __init__(
336 |     | -        self,
337 |     | -        in_features: int,
338 |     | -        out_features: int,
339 |     | -        r: int = 0,
340 |     | -        lora_alpha: int = 1,
341 |     | -        lora_dropout: float = 0.0,
342 |     | -        rslora: bool = False,
343 |     | -        lora_plus_scale: float = 1.0,
344 |     | -        merge_weights: bool = True,
345 |     | -        use_quick_lora: bool = False,
346 |     | -        pissa: bool = False,
347 |     | -        **kwargs
348 |     | -    ):
349 |     | -        RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
350 |     | -        if not isinstance(r, int) or r <= 0:
351 |     | -            raise ValueError("Lora rank r should be a positive integer")
352 |     | -        if pissa:
353 |     | -            raise ValueError("Pissa is not supported in model parallel by now")
354 |     | -        self.r = r
355 |     | -        self.lora_alpha = lora_alpha
356 |     | -        # Optional dropout
357 |     | -        if lora_dropout > 0.0:
358 |     | -            self.lora_dropout = nn.Dropout(p=lora_dropout)
359 |     | -        else:
360 |     | -            self.lora_dropout = lambda x: x
361 |     | -        # Mark the weight as unmerged
362 |     | -        self.merged = False
363 |     | -        self.merge_weights = merge_weights
364 |     | -
365 |     | -        # compatible
366 |     | -        self.name = self._name
367 |     | -
368 |     | -        # Actual trainable parameters
369 |     | -        self.lora_A = self.create_parameter(
370 |     | -            shape=[self.input_size_per_partition, r],
371 |     | -            dtype=self._dtype,
372 |     | -            is_bias=False,
373 |     | -            attr=paddle.ParamAttr(
374 |     | -                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
375 |     | -            ),
376 |     | -        )
377 |     | -        self.lora_B = self.create_parameter(
378 |     | -            shape=[r, self.out_features],
379 |     | -            dtype=self._dtype,
380 |     | -            is_bias=False,
381 |     | -            attr=paddle.ParamAttr(
382 |     | -                initializer=paddle.nn.initializer.Constant(value=0.0),
383 |     | -                learning_rate=lora_plus_scale,
384 |     | -            ),
385 |     | -        )
386 |     | -
387 |     | -        self.lora_A.is_distributed = True
388 |     | -        self.lora_A.split_axis = 0
389 |     | -        self.lora_B.is_distributed = False
390 |     | -        mark_as_sequence_parallel_parameter(self.lora_B)
391 |     | -        if not rslora:
392 |     | -            self.scaling = self.lora_alpha / self.r
393 |     | -        else:
394 |     | -            self.scaling = self.lora_alpha / math.sqrt(self.r)
395 |     | -
396 |     | -        # Freezing the pre-trained weight matrix
397 |     | -        self.weight.stop_gradient = True
398 |     | -        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
399 |     | -
400 |     | -    @property
401 |     | -    def use_quick_lora(self):
402 |     | -        # TODO(@gexiao): support qlora
403 |     | -        return False  # self._use_quick_lora and self.training and not self.merged
404 |     | -
405 |     | -    def train(self):
406 |     | -        super().train()
407 |     | -        if self.merge_weights and self.merged:
408 |     | -            # Make sure that the weights are not merged
409 |     | -            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
410 |     | -            self.weight.set_value(new_weight)
411 |     | -            self.merged = False
412 |     | -
413 |     | -    def eval(self):
414 |     | -        super().eval()
415 |     | -        if self.merge_weights and not self.merged:
416 |     | -            # Merge the weights and mark it
417 |     | -            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
418 |     | -            self.weight.set_value(new_weight)
419 |     | -            self.merged = True
420 |     | -
421 |     | -    def forward(self, x: paddle.Tensor):
422 |     | -        if not self.input_is_parallel:
423 |     | -            input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
424 |     | -        else:
425 |     | -            input_mp = x
426 |     | -
427 |     | -        if not is_mc2_valid():
428 |     | -            output_parallel = self.linear(input_mp, self.weight, name=self._name)
429 |     | -            output_ = ReduceScatterOp.apply(output_parallel)
430 |     | -            result_mp = output_ + self.bias if self.bias is not None else output_
431 |     | -        else:
432 |     | -            output_ = MC2RowSeqParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
433 |     | -            result_mp = output_ + self.bias if self.bias is not None else output_
434 |     | -
435 |     | -        if not self.merged:
436 |     | -            input_mp = self.lora_dropout(input_mp)
437 |     | -            if not is_mc2_valid():
438 |     | -                input_mp = input_mp @ self.lora_A
439 |     | -                input_mp = ReduceScatterOp.apply(input_mp)
440 |     | -            else:
441 |     | -                input_mp = MC2RowSeqParallelLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
442 |     | -            delta_mp = (input_mp @ self.lora_B) * self.scaling
443 |     | -            result_mp += delta_mp
444 |     | -        return result_mp
445 |     | -
446 |     | -    def extra_repr(self):
447 |     | -        name = f", name={self.name}" if self.name else ""
448 |     | -        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
449 |     | -
450 |     | -
451 | 301 | class ColumnParallelLoRALinear(ColumnParallelLinear):
452 | 302 |     def __init__(
453 | 303 |         self,
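For reference, the removed class's `train()`/`eval()` pair toggles between unmerged and merged weights via `W ± (lora_A @ lora_B) * scaling`, with `scaling = lora_alpha / r` (or `lora_alpha / sqrt(r)` when `rslora=True`). A minimal sketch of that identity on plain tensors (illustrative shapes and values, not the parallel layer itself):

```python
import paddle

r, in_f, out_f, lora_alpha = 8, 64, 32, 16
scaling = lora_alpha / r  # rslora=True would use lora_alpha / math.sqrt(r)

W = paddle.randn([in_f, out_f])  # frozen base weight
A = paddle.randn([in_f, r])      # lora_A
B = paddle.randn([r, out_f])     # lora_B (zero-initialized in the real layer)

merged = W + (A @ B) * scaling         # eval(): fold the LoRA delta into W
restored = merged - (A @ B) * scaling  # train(): peel the delta back off

print(paddle.allclose(W, restored, atol=1e-5))  # True up to float32 rounding
```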

@@ -578,126 +428,6 @@ def extra_repr(self):
578 | 428 |         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
579 | 429 |
580 | 430 |
581 |     | -class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
582 |     | -    def __init__(
583 |     | -        self,
584 |     | -        in_features: int,
585 |     | -        out_features: int,
586 |     | -        r: int = 0,
587 |     | -        lora_alpha: int = 1,
588 |     | -        lora_dropout: float = 0.0,
589 |     | -        rslora: bool = False,
590 |     | -        lora_plus_scale: float = 1.0,
591 |     | -        merge_weights: bool = True,
592 |     | -        lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
593 |     | -        use_quick_lora: bool = False,
594 |     | -        pissa: bool = False,
595 |     | -        **kwargs
596 |     | -    ):
597 |     | -        ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
598 |     | -        if not isinstance(r, int) or r <= 0:
599 |     | -            raise ValueError("Lora rank r should be a positive integer")
600 |     | -        if pissa:
601 |     | -            raise ValueError("Pissa is not supported in model parallel by now")
602 |     | -        self.r = r
603 |     | -        self.lora_alpha = lora_alpha
604 |     | -        # Optional dropout
605 |     | -        if lora_dropout > 0.0:
606 |     | -            self.lora_dropout = nn.Dropout(p=lora_dropout)
607 |     | -        else:
608 |     | -            self.lora_dropout = lambda x: x
609 |     | -        # Mark the weight as unmerged
610 |     | -        self.merged = False
611 |     | -        self.merge_weights = merge_weights
612 |     | -
613 |     | -        # compatible
614 |     | -        self.name = self._name
615 |     | -
616 |     | -        # Actual trainable parameters
617 |     | -        self.lora_A = self.create_parameter(
618 |     | -            shape=[in_features, r],
619 |     | -            dtype=self._dtype,
620 |     | -            is_bias=False,
621 |     | -            attr=lora_A_weight_attr,
622 |     | -        )
623 |     | -        self.lora_A.is_distributed = False
624 |     | -        mark_as_sequence_parallel_parameter(self.lora_A)
625 |     | -
626 |     | -        self.lora_B = self.create_parameter(
627 |     | -            shape=[r, self.output_size_per_partition],
628 |     | -            dtype=self._dtype,
629 |     | -            is_bias=False,
630 |     | -            attr=paddle.ParamAttr(
631 |     | -                initializer=paddle.nn.initializer.Constant(value=0.0),
632 |     | -                learning_rate=lora_plus_scale,
633 |     | -            ),
634 |     | -        )
635 |     | -
636 |     | -        self.lora_B.is_distributed = True
637 |     | -        self.lora_B.split_axis = 1
638 |     | -        if not rslora:
639 |     | -            self.scaling = self.lora_alpha / self.r
640 |     | -        else:
641 |     | -            self.scaling = self.lora_alpha / math.sqrt(self.r)
642 |     | -
643 |     | -        # Freezing the pre-trained weight matrix
644 |     | -        self.weight.stop_gradient = True
645 |     | -        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
646 |     | -
647 |     | -    @property
648 |     | -    def use_quick_lora(self):
649 |     | -        # TODO(@gexiao): support qlora
650 |     | -        return False  # self._use_quick_lora and self.training and not self.merged
651 |     | -
652 |     | -    def train(self):
653 |     | -        super().train()
654 |     | -        if self.merge_weights and self.merged:
655 |     | -            # Make sure that the weights are not merged
656 |     | -            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
657 |     | -            self.weight.set_value(new_weight)
658 |     | -            self.merged = False
659 |     | -
660 |     | -    def eval(self):
661 |     | -        super().eval()
662 |     | -        if self.merge_weights and not self.merged:
663 |     | -            # Merge the weights and mark it
664 |     | -            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
665 |     | -            self.weight.set_value(new_weight)
666 |     | -            self.merged = True
667 |     | -
668 |     | -    def forward(self, x: paddle.Tensor):
669 |     | -        if not is_mc2_valid():
670 |     | -            if self.is_mp:
671 |     | -                input_parallel = AllGatherOp.apply(x)
672 |     | -            else:
673 |     | -                input_parallel = x
674 |     | -            result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
675 |     | -        else:
676 |     | -            result_mp = MC2ColumnSeqParallelLinear.apply(x, self.weight, self.model_parallel_group)
677 |     | -            if self.bias is not None:
678 |     | -                result_mp += self.bias
679 |     | -
680 |     | -        if not self.merged:
681 |     | -            input_a = self.lora_dropout(x) @ self.lora_A
682 |     | -            if not is_mc2_valid():
683 |     | -                input_a = AllGatherOp.apply(input_a)
684 |     | -                delta_mp = (input_a @ self.lora_B) * self.scaling
685 |     | -            else:
686 |     | -                input_a = MC2ColumnSeqParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
687 |     | -                delta_mp = input_a * self.scaling
688 |     | -            result_mp += delta_mp
689 |     | -
690 |     | -        if self.gather_output and self.is_mp:
691 |     | -            result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
692 |     | -        else:
693 |     | -            result = result_mp
694 |     | -        return result
695 |     | -
696 |     | -    def extra_repr(self):
697 |     | -        name = f", name={self.name}" if self.name else ""
698 |     | -        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
699 |     | -
700 |     | -
701 | 431 | class LoRAMergedLinear(nn.Linear):
702 | 432 |     # LoRA implemented in a dense layer with merged linear weights for q, k, v
703 | 433 |     def __init__(
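Both removed sequence-parallel classes compute the same LoRA update as the layers that remain, `base(x) + (lora_dropout(x) @ lora_A @ lora_B) * scaling`; the collectives (`AllGatherOp`, `ReduceScatterOp`, or the fused MC2 ops) are only wrapped around those matmuls. A single-card sketch of that pattern, using a hypothetical `TinyLoRALinear` with illustrative defaults and no collectives:

```python
import paddle
import paddle.nn as nn


class TinyLoRALinear(nn.Linear):
    """Hypothetical single-card layer mirroring the removed classes' LoRA math."""

    def __init__(self, in_f, out_f, r=8, lora_alpha=16, lora_dropout=0.0):
        super().__init__(in_f, out_f)
        self.scaling = lora_alpha / r
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else (lambda x: x)
        self.lora_A = self.create_parameter(shape=[in_f, r], dtype=self._dtype, is_bias=False)
        self.lora_B = self.create_parameter(
            shape=[r, out_f],
            dtype=self._dtype,
            is_bias=False,
            attr=paddle.ParamAttr(initializer=nn.initializer.Constant(value=0.0)),
        )
        self.weight.stop_gradient = True  # only the LoRA factors are trainable

    def forward(self, x):
        base = super().forward(x)  # W x + b; the parallel classes wrap this in collectives
        delta = (self.lora_dropout(x) @ self.lora_A @ self.lora_B) * self.scaling
        return base + delta


layer = TinyLoRALinear(64, 32)
out = layer(paddle.randn([2, 16, 64]))  # [batch, seq, hidden] -> [2, 16, 32]
```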