diff --git a/accessory/model/LLM/llama_adapter.py b/accessory/model/LLM/llama_adapter.py index 74e8b572..85a75650 100644 --- a/accessory/model/LLM/llama_adapter.py +++ b/accessory/model/LLM/llama_adapter.py @@ -24,6 +24,7 @@ import configs.global_configs if configs.global_configs.USE_FLASH_ATTENTION: from flash_attn import flash_attn_func +from .llama import precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb, repeat_kv default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5)) @@ -50,47 +51,6 @@ class ModelArgs: bias_tuning: bool = False # bias -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() diff --git a/accessory/model/LLM/llama_peft.py b/accessory/model/LLM/llama_peft.py index 4f0b5dc6..468a626d 100644 --- a/accessory/model/LLM/llama_peft.py +++ b/accessory/model/LLM/llama_peft.py @@ -1,280 +1,43 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
- -from typing import Optional, Tuple from dataclasses import dataclass -import math import functools -import torch -from torch import nn -import torch.nn.functional as F +from ..peft import wrap_lora -import fairscale.nn.model_parallel.initialize as fs_init -from fairscale.nn.model_parallel.layers import ( - ParallelEmbedding, - RowParallelLinear, - ColumnParallelLinear +from .llama import ( + ModelArgs as LLaMAModelArgs, + Transformer as LLaMATransformer, ) -from ..peft import LoraColumnParallelLinear, LoraRowParallelLinear - -from apex.normalization import FusedRMSNorm as RMSNorm -import open_clip - -import configs.global_configs -if configs.global_configs.USE_FLASH_ATTENTION: - from flash_attn import flash_attn_func - -default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5)) - @dataclass -class ModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: Optional[int] = None - vocab_size: int = -1 # defined later by tokenizer - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None - norm_eps: float = 1e-5 - - max_batch_size: int = 32 - max_seq_len: int = 2048 - +class ModelArgs(LLaMAModelArgs): lora_rank: int = -1 # lora - bias_tuning: bool = True # bias -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - model_parallel_size = fs_init.get_model_parallel_world_size() - self.n_local_heads = args.n_heads // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // model_parallel_size - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // args.n_heads - - self.wq = LoraColumnParallelLinear( - args.dim, - args.n_heads * self.head_dim, - bias=args.bias_tuning, - gather_output=False, - init_method=default_linear_init, - lora_rank=args.lora_rank - ) - self.wk = LoraColumnParallelLinear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=args.bias_tuning, - gather_output=False, 
- init_method=default_linear_init, - lora_rank=args.lora_rank - ) - self.wv = LoraColumnParallelLinear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=args.bias_tuning, - gather_output=False, - init_method=default_linear_init, - lora_rank=args.lora_rank - ) - self.wo = LoraRowParallelLinear( - args.n_heads * self.head_dim, - args.dim, - bias=args.bias_tuning, - input_is_parallel=True, - init_method=default_linear_init, - lora_rank=args.lora_rank - ) - - self.args = args - - self.flash = configs.global_configs.USE_FLASH_ATTENTION - - def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): - bsz, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - - keys = xk - values = xv - - # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - - if self.flash: - output = flash_attn_func(xq, keys, values, dropout_p=0.0, causal=True) - output = output.contiguous().view(bsz, seqlen, -1) - else: - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - output = F.scaled_dot_product_attention(xq, keys, values, dropout_p=0.0, mask=mask) - - output = output.transpose( - 1, 2 - ).contiguous().view(bsz, seqlen, -1) - - return self.wo(output) - - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - args: ModelArgs, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = LoraColumnParallelLinear( - dim, hidden_dim, bias=args.bias_tuning, gather_output=False, - init_method=default_linear_init, lora_rank=args.lora_rank - ) - self.w2 = LoraRowParallelLinear( - hidden_dim, dim, bias=args.bias_tuning, input_is_parallel=True, - init_method=default_linear_init, lora_rank=args.lora_rank - ) - self.w3 = LoraColumnParallelLinear( - dim, hidden_dim, bias=args.bias_tuning, gather_output=False, - init_method=default_linear_init, lora_rank=args.lora_rank - ) - - # @torch.compile - def _silu_gating(self, x, y): - return F.silu(x) * y - - def forward(self, x): - return self.w2(self._silu_gating(self.w1(x), self.w3(x))) - - -class TransformerBlock(nn.Module): - def __init__(self, layer_id: int, args: ModelArgs): - super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.head_dim = args.dim // args.n_heads - self.attention = Attention(args) - self.feed_forward = FeedForward( - dim=args.dim, - hidden_dim=4 * args.dim, - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, - args=args - ) - self.layer_id = layer_id - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - - def _forward_ffn(self, h): - return h + self.feed_forward(self.ffn_norm(h)) - - def _forward_attention(self, x, start_pos, freqs_cis, mask): - return x + self.attention(self.attention_norm(x), start_pos, 
freqs_cis, mask) - - def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): - h = self._forward_attention(x, start_pos, freqs_cis, mask) - out = self._forward_ffn(h) - return out - - -class Transformer(nn.Module): - is_peft = True - def __init__(self, params: ModelArgs, with_visual=False): - super().__init__() - self.params = params - self.vocab_size = params.vocab_size - self.n_layers = params.n_layers - self.tok_embeddings = ParallelEmbedding( - params.vocab_size, params.dim, init_method=default_linear_init - ) - - self.layers = torch.nn.ModuleList() - for layer_id in range(params.n_layers): - self.layers.append(TransformerBlock(layer_id, params)) - - self.norm = RMSNorm(params.dim, eps=params.norm_eps) - self.output = ColumnParallelLinear( - params.dim, params.vocab_size, bias=False, init_method=default_linear_init - ) - - self.freqs_cis = precompute_freqs_cis( - self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 - ) - - self.image_words = 0 - if with_visual: - print("build llama model with clip") - torch.set_default_tensor_type(torch.cuda.HalfTensor) - self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') - torch.set_default_tensor_type(torch.FloatTensor) - for name, param in self.clip.named_parameters(): - param.requires_grad = False - in_dim = self.clip.visual.proj.shape[1] - # in_dim = 3 - self.clip_proj = nn.Linear(in_dim, params.dim) - self.clip_proj_norm = nn.LayerNorm(params.dim) - self.image_words = 257 - +class Transformer(LLaMATransformer): + def __init__(self, args: ModelArgs, with_visual: bool = False) -> None: + super().__init__(args, with_visual) + self._setup_peft() self.set_default_trainability() + def _setup_peft(self): + wrap_lora_with_args = functools.partial( + wrap_lora, + lora_rank=self.params.lora_rank, + bias=self.params.bias_tuning, + ) + def wrap_attn(attn): + attn.wq = wrap_lora_with_args(attn.wq) + attn.wk = wrap_lora_with_args(attn.wk) + attn.wv = wrap_lora_with_args(attn.wv) + attn.wo = wrap_lora_with_args(attn.wo) + def wrap_ffn(ffn): + ffn.w1 = wrap_lora_with_args(ffn.w1) + ffn.w2 = wrap_lora_with_args(ffn.w2) + ffn.w3 = wrap_lora_with_args(ffn.w3) + for layer in self.layers: + wrap_attn(layer.attention) + wrap_ffn(layer.feed_forward) def get_trainable_params(self): trainable = {} @@ -283,10 +46,8 @@ def get_trainable_params(self): trainable_key_words = ['norm', 'bias', 'lora'] if any([_ in name for _ in trainable_key_words]): trainable[name] = para - return trainable - def set_default_trainability(self): for key, value in self.named_parameters(): value.requires_grad = False @@ -295,86 +56,3 @@ def set_default_trainability(self): value.data = value.data.float() value.requires_grad = True - - @torch.no_grad() - def clip_encode_image(self, x): - # modified from CLIP - x = self.clip.visual.conv1(x) # shape = [*, width, grid, grid] - # shape = [*, width, grid ** 2] - x = x.reshape(x.shape[0], x.shape[1], -1) - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, - x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.clip.visual.positional_embedding.to(x.dtype) - x = self.clip.visual.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.clip.visual.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - # preserve all spatial tokens - x = self.clip.visual.ln_post(x[:, :, :]) - - if self.clip.visual.proj 
is not None: - x = x @ self.clip.visual.proj - - return x - - - def encode_image(self, image): - # return self.patch_embed(image) - image_tokens = self.clip_encode_image(image) - image_tokens = self.clip_proj_norm(self.clip_proj(image_tokens)) - return image_tokens - - def forward(self, examples, image=None): - _bsz, seqlen = examples.shape - h = self.tok_embeddings(examples) - self.freqs_cis = self.freqs_cis.to(h.device) - start_pos = 0 - - if image is not None: - image_tokens = self.encode_image(image) - h = torch.cat((image_tokens, h), dim=1) - start_pos = image_tokens.shape[1] - seqlen = h.shape[1] - - # print(f"image: {start_pos}, text: {seqlen - start_pos}, seq_len: {seqlen}") - - freqs_cis = self.freqs_cis[:seqlen] - mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device) - mask = torch.triu(mask, diagonal=1).type_as(h) - for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) - h = self.norm(h) - output = self.output(h[:, start_pos:, :]) - return output - - - @torch.inference_mode() - def forward_inference(self, tokens: torch.Tensor, start_pos: int, image=None): - assert start_pos==0 - _bsz, seqlen = tokens.shape - h = self.tok_embeddings(tokens) - self.freqs_cis = self.freqs_cis.to(h.device) - - if image is not None: - image_tokens = self.encode_image(image) - h = torch.cat((image_tokens, h), dim=1) - start_pos = start_pos + image_tokens.shape[1] - seqlen = h.shape[1] - - freqs_cis = self.freqs_cis[:seqlen] - - mask = None - if seqlen > 1: - mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device) - mask = torch.triu(mask, diagonal=1).type_as(h) - mask[:, :, :, :start_pos] = 0 - - - for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) - h = self.norm(h) - output = self.output(h[:, -1, :]) # only compute last logits - return output.float() diff --git a/accessory/model/peft.py b/accessory/model/peft.py index 93002cc4..1b461396 100644 --- a/accessory/model/peft.py +++ b/accessory/model/peft.py @@ -1,23 +1,8 @@ -from typing import Callable, Optional - import torch import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Linear, Parameter, init -from torch import Tensor - -from timm.models.layers import trunc_normal_, lecun_normal_, to_2tuple +from timm.models.layers import trunc_normal_ -from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size -from fairscale.nn.model_parallel.mappings import ( - copy_to_model_parallel_region, - gather_from_model_parallel_region, - reduce_from_model_parallel_region, - scatter_to_model_parallel_region, -) -from fairscale.nn.model_parallel.utils import VocabUtility, divide_and_check_no_remainder -from fairscale.nn.model_parallel.layers import _initialize_affine_weight from fairscale.nn.model_parallel.layers import ( RowParallelLinear, ColumnParallelLinear, @@ -25,184 +10,59 @@ class LoraColumnParallelLinear(ColumnParallelLinear): - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. + """ColumnParallelLinear extended with LoRA support""" - Arguments: - in_features: first dimension of matrix A. - out_features: second dimension of matrix A. - bias: If true, add bias - gather_output: If true, call all-gether on output and make Y avaiable - to all GPUs, otherwise, every GPU will have its output - which is Y_i = XA_i - init_method: method to initialize weights. Note that bias is always set - to zero. 
- stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - gather_output: bool = True, - init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, - stride: int = 1, - keep_master_weight_for_test: bool = False, - lora_rank=0 - ) -> None: - nn.Module.__init__(self) - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - # Divide the weight matrix along the last dimension. - world_size = get_model_parallel_world_size() - self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size) - - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) - if bias: - self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter("bias", None) + def __init__(self, *args, **kwargs) -> None: + self.lora_rank = kwargs.pop("lora_rank", 0) + super().__init__(*args, **kwargs) - # Initialize weight. - self.master_weight = _initialize_affine_weight( - self.weight, - self.out_features, - self.in_features, - self.output_size_per_partition, - 0, - init_method, - stride=stride, - return_master_weight=keep_master_weight_for_test, - ) - - self.lora_rank = lora_rank if self.lora_rank > 0: - # if world_size > 1: - # raise NotImplemented("Lora with model parallel with change the original behavior, not yet supported") self.lora_a = nn.Linear(self.in_features, self.lora_rank, bias=False) trunc_normal_(self.lora_a.weight, std=.02) - self.lora_b = ColumnParallelLinear(self.lora_rank, self.out_features, bias=False, gather_output=gather_output) + self.lora_b = ColumnParallelLinear( + self.lora_rank, self.out_features, bias=False, + gather_output=self.gather_output, + ) nn.init.zeros_(self.lora_b.weight) else: self.lora_a = None self.lora_b = None - def get_master_weight(self) -> torch.Tensor: - return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1) - def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore - # Set up backprop all-reduce. - input_parallel = copy_to_model_parallel_region(input_) - # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight, self.bias) - if self.lora_a is not None: - modification = self.lora_b(self.lora_a(input_)) - else: - modification = None - - if self.gather_output: - # All-gather across the partitions. 
-            output = gather_from_model_parallel_region(output_parallel)
-        else:
-            output = output_parallel
-
-        if modification is not None:
-            output = output + modification
+        output = super().forward(input_)
+        if self.lora_rank > 0:
+            output += self.lora_b(self.lora_a(input_))
         return output
+
+    @staticmethod
+    def from_non_lora(layer: ColumnParallelLinear, **kwargs) -> "LoraColumnParallelLinear":
+        new_layer_kwargs = dict(
+            in_features=layer.in_features,
+            out_features=layer.out_features,
+            bias=layer.bias is not None,
+            gather_output=layer.gather_output,
+            init_method=lambda x: x,
+            keep_master_weight_for_test=layer.master_weight is not None,
+        )
+        new_layer_kwargs.update(kwargs)
+        layer_with_lora = LoraColumnParallelLinear(**new_layer_kwargs)
+        layer_with_lora.weight.data.copy_(layer.weight)
+        if layer_with_lora.bias is not None and layer.bias is not None:
+            layer_with_lora.bias.data.copy_(layer.bias)
+        return layer_with_lora


 class LoraRowParallelLinear(RowParallelLinear):
-    """Linear layer with row parallelism.
-
-    The linear layer is defined as Y = XA + b. A is parallelized along
-    its first dimension and X along its second dimension as:
-               -   -
-              | A_1 |
-              | .   |
-          A = | .   |        X = [X_1, ..., X_p]
-              | .   |
-              | A_p |
-               -   -
-
-    Arguments:
-        in_features: first dimension of matrix A.
-        out_features: second dimension of matrix A.
-        bias: If true, add bias. Note that bias is not parallelized.
-        input_is_parallel: If true, we assume that the input is already
-                           split across the GPUs and we do not split
-                           again.
-        init_method: method to initialize weights. Note that bias is always set
-                     to zero.
-        stride: For the strided linear layers.
-        keep_master_weight_for_test: This was added for testing and should be
-                                     set to False. It returns the master weights
-                                     used for initialization.
-    """
-
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool = True,
-        input_is_parallel: bool = False,
-        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
-        stride: int = 1,
-        keep_master_weight_for_test: bool = False,
-        lora_rank = 0
-    ):
-        nn.Module.__init__(self)
-
-        # Keep input parameters
-        self.in_features = in_features
-        self.out_features = out_features
-        self.input_is_parallel = input_is_parallel
-        # Divide the weight matrix along the last dimension.
-        world_size = get_model_parallel_world_size()
-        self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size)
+    """RowParallelLinear with LoRA support"""

-        # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
-        # we allocate the transpose.
-        self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
-        if bias:
-            self.bias = Parameter(torch.Tensor(self.out_features))
-            # Always initialize bias to zero.
-            with torch.no_grad():
-                self.bias.zero_()
-        else:
-            self.register_parameter("bias", None)
-
-        # Initialize weight.
-        self.master_weight = _initialize_affine_weight(
-            self.weight,
-            self.out_features,
-            self.in_features,
-            self.input_size_per_partition,
-            1,
-            init_method,
-            stride=stride,
-            return_master_weight=keep_master_weight_for_test,
-        )
+    def __init__(self, *args, **kwargs) -> None:
+        self.lora_rank = kwargs.pop("lora_rank", 0)
+        super().__init__(*args, **kwargs)

-        self.lora_rank = lora_rank
         if self.lora_rank > 0:
-            # if world_size > 1:
-            #     raise NotImplemented("Lora with model parallel with change the original behavior, not yet supported")
-            self.lora_a = RowParallelLinear(self.in_features, self.lora_rank, bias=False, input_is_parallel=True)
+            self.lora_a = RowParallelLinear(
+                self.in_features, self.lora_rank, bias=False,
+                input_is_parallel=self.input_is_parallel,
+            )
             trunc_normal_(self.lora_a.weight, std=.02)
             self.lora_b = nn.Linear(self.lora_rank, self.out_features, bias=False)
             nn.init.zeros_(self.lora_b.weight)
@@ -210,24 +70,35 @@ def __init__(
             self.lora_a = None
             self.lora_b = None

-    def get_master_weight(self) -> torch.Tensor:
-        return gather_from_model_parallel_region(self.weight.data)
-
     def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type:ignore
-        # Set up backprop all-reduce.
-        if self.input_is_parallel:
-            input_parallel = input_
-        else:
-            input_parallel = scatter_to_model_parallel_region(input_)
-        # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
-        # All-reduce across all the partitions.
-        output_ = reduce_from_model_parallel_region(output_parallel)
-        if self.lora_a is not None:
-            modification = self.lora_b(self.lora_a(input_parallel))
-            output_ = output_ + modification
-        if self.bias is not None:
-            output = output_ + self.bias
-        else:
-            output = output_
-        return output
\ No newline at end of file
+        output = super().forward(input_)
+        if self.lora_rank > 0:
+            output += self.lora_b(self.lora_a(input_))
+        return output
+
+    @staticmethod
+    def from_non_lora(layer: RowParallelLinear, **kwargs) -> "LoraRowParallelLinear":
+        new_layer_kwargs = dict(
+            in_features=layer.in_features,
+            out_features=layer.out_features,
+            bias=layer.bias is not None,
+            input_is_parallel=layer.input_is_parallel,
+            init_method=lambda x: x,
+            keep_master_weight_for_test=layer.master_weight is not None,
+        )
+        new_layer_kwargs.update(kwargs)
+        layer_with_lora = LoraRowParallelLinear(**new_layer_kwargs)
+        layer_with_lora.weight.data.copy_(layer.weight)
+        if layer_with_lora.bias is not None and layer.bias is not None:
+            layer_with_lora.bias.data.copy_(layer.bias)
+        return layer_with_lora
+
+
+def wrap_lora(layer: nn.Module, **kwargs):
+    base_module_to_lora_module = [
+        (ColumnParallelLinear, LoraColumnParallelLinear),
+        (RowParallelLinear, LoraRowParallelLinear),
+    ]
+    for base_module, lora_module in base_module_to_lora_module:
+        if isinstance(layer, base_module):
+            return lora_module.from_non_lora(layer, **kwargs)
+    raise NotImplementedError(f"LoRA wrapping for layer of type {type(layer)} is not implemented.")
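Below is a minimal, illustrative usage sketch of the wrap_lora helper added in accessory/model/peft.py. It is not part of the patch itself: it assumes the accessory/ directory is on PYTHONPATH (so "model.peft" resolves, matching the repo's "import configs.global_configs" style) and that a single-process model-parallel group can be initialized, since fairscale's parallel layers cannot be constructed without one. The layer sizes and LoRA rank below are arbitrary.

# Illustrative sketch, not part of this patch.
import os

import torch
import torch.distributed as dist
import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import ColumnParallelLinear

from model.peft import wrap_lora  # assumes accessory/ is on PYTHONPATH

# fairscale's parallel layers require an initialized model-parallel group;
# a single-process gloo group is enough for demonstration purposes.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
fs_init.initialize_model_parallel(1)

base = ColumnParallelLinear(64, 64, bias=False, gather_output=False)
lora = wrap_lora(base, lora_rank=8, bias=False)  # LoraColumnParallelLinear with copied weights

x = torch.randn(2, 64)
# lora_b is zero-initialized, so right after wrapping the LoRA branch adds
# nothing and the wrapped layer reproduces the original layer's output.
assert torch.allclose(lora(x), base(x))

# These are the extra parameters that get_trainable_params() in llama_peft.py
# leaves trainable (together with biases and norm weights).
print([name for name, _ in lora.named_parameters() if "lora" in name])

Because wrap_lora copies the original weight (and bias, when present) into the new layer, applying it to an already-built Transformer, as _setup_peft does in llama_peft.py, preserves the pretrained behavior until the LoRA branches are trained.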