 
 from .lora_quick_layers import quick_lora
 
-if "npu" in paddle.device.get_all_custom_device_type():
+
+def is_mc2_valid():
+    return "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0"))
+
+
+if is_mc2_valid():
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        MC2ColumnSeqParallelLinear,
+        MC2RowSeqParallelLinear,
+    )
+
     from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
 else:
     MC2LoRaRowParallelLinear = None
     MC2LoRaColumnParallelLinear = None
+    MC2ColumnSeqParallelLinear = None
+    MC2RowSeqParallelLinear = None
+
+
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        AllGatherOp,
+        ColumnSequenceParallelLinear,
+        ReduceScatterOp,
+        RowSequenceParallelLinear,
+        mark_as_sequence_parallel_parameter,
+    )
+except ImportError:
+
+    class ColumnSequenceParallelLinear:
+        pass
+
+    class RowSequenceParallelLinear:
+        pass
+
+    AllGatherOp = None
+    ReduceScatterOp = None
+    mark_as_sequence_parallel_parameter = None
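
For orientation, a minimal sketch of how the gating above resolves at import time; it assumes os is imported at the top of the module and describes typical device lists, none of which is part of the patch itself:

import os
import paddle

# On CPU/GPU builds the custom-device list is usually empty, so is_mc2_valid()
# is falsy and the generic AllGatherOp / ReduceScatterOp path is taken.
print(paddle.device.get_all_custom_device_type())

# On an NPU build the fused MC2 kernels are used only when the env flag is also set:
os.environ["MC2"] = "1"   # is_mc2_valid() is truthy only if "npu" is listed and MC2 != 0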
 
 
 class LoRALinear(nn.Linear):
@@ -298,6 +331,123 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
 
 
+class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
+        merge_weights: bool = True,
+        use_quick_lora: bool = False,
+        pissa: bool = False,
+        **kwargs
+    ):
+        RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+        if not isinstance(r, int) or r <= 0:
+            raise ValueError("LoRA rank r should be a positive integer")
+        if pissa:
+            raise ValueError("PiSSA is not supported in model parallel mode yet")
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+        # compatible
+        self.name = self._name
+
+        # Actual trainable parameters
+        self.lora_A = self.create_parameter(
+            shape=[self.input_size_per_partition, r],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
+            ),
+        )
+        self.lora_B = self.create_parameter(
+            shape=[r, self.out_features],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
+        )
+
+        self.lora_A.is_distributed = True
+        self.lora_A.split_axis = 0
+        self.lora_B.is_distributed = False
+        mark_as_sequence_parallel_parameter(self.lora_B)
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+        # Freezing the pre-trained weight matrix
+        self.weight.stop_gradient = True
+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        # TODO(@gexiao): support qlora
+        return False  # self._use_quick_lora and self.training and not self.merged
+
+    def train(self):
+        super().train()
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
+    def eval(self):
+        super().eval()
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = True
+
+    def forward(self, x: paddle.Tensor):
+        if not self.input_is_parallel:
+            input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
+        else:
+            input_mp = x
+
+        if not is_mc2_valid():
+            output_parallel = self.linear(input_mp, self.weight, name=self._name)
+            output_ = ReduceScatterOp.apply(output_parallel)
+            result_mp = output_ + self.bias if self.bias is not None else output_
+        else:
+            output_ = MC2RowSeqParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
+            result_mp = output_ + self.bias if self.bias is not None else output_
+
+        if not self.merged:
+            input_mp = self.lora_dropout(input_mp)
+            if not is_mc2_valid():
+                input_mp = input_mp @ self.lora_A
+                input_mp = ReduceScatterOp.apply(input_mp)
+            else:
+                input_mp = MC2RowSeqParallelLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
+            delta_mp = (input_mp @ self.lora_B) * self.scaling
+            result_mp += delta_mp
+        return result_mp
+
+    def extra_repr(self):
+        name = f", name={self.name}" if self.name else ""
+        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
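
Setting the ReduceScatter/MC2 collectives aside, the forward pass above reduces to the usual LoRA arithmetic. A minimal single-process sketch with toy shapes (illustrative only, not part of the patch):

import paddle

x = paddle.randn([2, 16])    # [tokens, in_features]
W = paddle.randn([16, 8])    # frozen base weight
A = paddle.randn([16, 4])    # lora_A, rank r = 4
B = paddle.zeros([4, 8])     # lora_B starts at zero, so the update is initially a no-op
scaling = 8 / 4              # lora_alpha / r (with lora_alpha = 8), or lora_alpha / sqrt(r) with rslora

out = x @ W + (x @ A @ B) * scaling   # base projection plus the low-rank update
print(out.shape)             # [2, 8]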
 class ColumnParallelLoRALinear(ColumnParallelLinear):
     def __init__(
         self,
@@ -428,6 +578,126 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
 
 
+class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
+        merge_weights: bool = True,
+        lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
+        use_quick_lora: bool = False,
+        pissa: bool = False,
+        **kwargs
+    ):
+        ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
+        if not isinstance(r, int) or r <= 0:
+            raise ValueError("LoRA rank r should be a positive integer")
+        if pissa:
+            raise ValueError("PiSSA is not supported in model parallel mode yet")
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.0:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+        # compatible
+        self.name = self._name
+
+        # Actual trainable parameters
+        self.lora_A = self.create_parameter(
+            shape=[in_features, r],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=lora_A_weight_attr,
+        )
+        self.lora_A.is_distributed = False
+        mark_as_sequence_parallel_parameter(self.lora_A)
+
+        self.lora_B = self.create_parameter(
+            shape=[r, self.output_size_per_partition],
+            dtype=self._dtype,
+            is_bias=False,
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
+        )
+
+        self.lora_B.is_distributed = True
+        self.lora_B.split_axis = 1
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
+
+        # Freezing the pre-trained weight matrix
+        self.weight.stop_gradient = True
+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        # TODO(@gexiao): support qlora
+        return False  # self._use_quick_lora and self.training and not self.merged
+
+    def train(self):
+        super().train()
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
+    def eval(self):
+        super().eval()
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = True
+
+    def forward(self, x: paddle.Tensor):
+        if not is_mc2_valid():
+            if self.is_mp:
+                input_parallel = AllGatherOp.apply(x)
+            else:
+                input_parallel = x
+            result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
+        else:
+            result_mp = MC2ColumnSeqParallelLinear.apply(x, self.weight, self.model_parallel_group)
+            if self.bias is not None:
+                result_mp += self.bias
+
+        if not self.merged:
+            input_a = self.lora_dropout(x) @ self.lora_A
+            if not is_mc2_valid():
+                input_a = AllGatherOp.apply(input_a)
+                delta_mp = (input_a @ self.lora_B) * self.scaling
+            else:
+                input_a = MC2ColumnSeqParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = input_a * self.scaling
+            result_mp += delta_mp
+
+        if self.gather_output and self.is_mp:
+            result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
+        else:
+            result = result_mp
+        return result
+
+    def extra_repr(self):
+        name = f", name={self.name}" if self.name else ""
+        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
+
+
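
To make the train()/eval() weight toggling in these classes concrete, a small single-process sketch of the merge/unmerge identity (toy shapes, illustrative only, not part of the patch):

import paddle

W = paddle.randn([16, 8])            # frozen base weight
A = paddle.randn([16, 4])            # lora_A
B = paddle.randn([4, 8])             # lora_B (non-zero here so the check is meaningful)
scaling = 0.5

merged = W + A @ B * scaling         # what eval() writes into self.weight
restored = merged - A @ B * scaling  # what train() writes back
print(bool(paddle.allclose(W, restored)))  # True, up to floating-point error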
 class LoRAMergedLinear(nn.Linear):
     # LoRA implemented in a dense layer with merged linear weights for q, k, v
     def __init__(
|
|