diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 95d57e12f5a..9ea683e4174 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -13,6 +13,7 @@ runtime.python_library(
     name = "llama_transformer",
     srcs = [
         "llama_transformer.py",
+        "lora.py",
         "rope.py",
         "attention.py",
         "model_args.py",
diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index aa53b330837..6f23456eaaa 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -325,7 +326,20 @@ def update(

 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+    ):
+        """
+        Multi-head attention layer.
+
+        Args:
+            args (ModelArgs): Model configuration parameters.
+            layer_id (int): Layer index.
+            rope (Rope): Rotary position embedding module.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -350,16 +364,60 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
             self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wq = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "q_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wk = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wk = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "k_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wv = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wv = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "v_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
+        )
+        self.wo = (
+            LoRALinear(
+                in_dim=args.n_kv_heads * args.head_dim,
+                out_dim=args.dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "output_proj" in args.target_modules
+            else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id

diff --git a/examples/models/llama/lora.py b/examples/models/llama/lora.py
new file mode 100644
index 00000000000..12c1c4e5d68
--- /dev/null
+++ b/examples/models/llama/lora.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+
+
+class LoRALinear(nn.Module):
+    """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_."""
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.rank = rank
+        self.alpha = alpha
+        self.use_bias = use_bias
+        self.dropout = dropout
+
+        linear = nn.Linear(in_dim, out_dim, bias=use_bias)
+        weight = linear.weight
+        bias = linear.bias if self.use_bias else None
+        self.register_parameter("weight", nn.Parameter(weight))
+        self.register_parameter(
+            "bias", nn.Parameter(bias) if bias is not None else None
+        )
+
+        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
+        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
+        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.nn.functional.linear(x, self.weight, self.bias)
+        lora_out = self.lora_a(self.dropout(x))
+        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+
+        return out + lora_out
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 5734cd66ef7..18acda9fe93 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -55,8 +55,18 @@ class ModelArgs:
     eos_count: int = 2

     quantization_args: Optional[dict] = None
+    # LoRA for QAT.
     lora_args: Optional[dict] = None

+    # LoRA arguments to set up a LoRA inference model.
+    # These arguments come directly from a torchtune LoRA config.
+    r: Optional[int] = None  # Rank.
+    lora_alpha: Optional[int] = None  # Alpha.
+    # Eg. q_proj, k_proj, v_proj, output_proj
+    target_modules: Optional[list] = None
+    peft_type: Optional[str] = None  # PEFT type.
+    base_model_name_or_path: Optional[str] = None  # Base model name or path.
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
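Reviewer note, not part of the diff: a minimal standalone sketch of what the new LoRALinear computes, assuming the import path this change introduces (executorch.examples.models.llama.lora). It checks that the layer's output is the base projection plus the (alpha / rank)-scaled low-rank term, matching forward() above.

# Illustration only; exercises LoRALinear outside the model.
import torch
import torch.nn.functional as F

from executorch.examples.models.llama.lora import LoRALinear

torch.manual_seed(0)
layer = LoRALinear(in_dim=16, out_dim=32, rank=4, alpha=8.0, dropout=0.0, use_bias=False)

x = torch.randn(2, 16)
out = layer(x)

# Recompute by hand: W x + (alpha / rank) * B(A x). With dropout=0.0 the dropout
# module is nn.Identity, so the two paths agree.
expected = F.linear(x, layer.weight, layer.bias) + (layer.alpha / layer.rank) * layer.lora_b(
    layer.lora_a(x)
)
assert out.shape == (2, 32)
assert torch.allclose(out, expected)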
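And a hedged sketch of how the new ModelArgs fields drive the projection choice in AttentionMHA. Only the field names come from this diff; the values are hypothetical, and all other ModelArgs fields are assumed to keep their defaults.

# Hypothetical LoRA config values; only the field names are from this diff.
from executorch.examples.models.llama.model_args import ModelArgs

args = ModelArgs(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj"],
)

# AttentionMHA keys off membership in target_modules: with this config wq/wk/wv
# become LoRALinear, while wo stays a plain nn.Linear because "output_proj" is absent.
assert args.target_modules is not None and "q_proj" in args.target_modules
assert "output_proj" not in args.target_modules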