1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
@@ -13,6 +13,7 @@ runtime.python_library(
name = "llama_transformer",
srcs = [
"llama_transformer.py",
"lora.py",
"rope.py",
"attention.py",
"model_args.py",
77 changes: 69 additions & 8 deletions examples/models/llama/attention.py
@@ -325,7 +325,20 @@ def update(

@register_attention("mha")
class AttentionMHA(Attention):
def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
def __init__(
self,
args: ModelArgs,
layer_id: int,
rope: Rope,
):
"""
Multi-head attention layer.

Args:
args (ModelArgs): Model configuration parameters.
layer_id (int): Layer index.
rope (Rope): Rotary position embedding module.
"""
super().__init__()
self.use_kv_cache = args.use_kv_cache
self.n_heads = args.n_heads
@@ -350,16 +363,64 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

self.wq = nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
self.wq = (
LoRALinear(
in_dim=args.dim,
out_dim=args.n_heads * args.head_dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None
and "q_proj" in args.target_modules
else nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
)
)
self.wk = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
self.wk = (
LoRALinear(
in_dim=args.dim,
out_dim=args.n_kv_heads * args.head_dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None
and "k_proj" in args.target_modules
else nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
)
self.wv = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
self.wv = (
LoRALinear(
in_dim=args.dim,
out_dim=args.n_kv_heads * args.head_dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None
and "v_proj" in args.target_modules
else nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
)
self.wo = (
LoRALinear(
in_dim=args.n_heads * args.head_dim,
out_dim=args.dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None
and "output_proj" in args.target_modules
else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

self.layer_id = layer_id

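As an illustration only (not part of this diff), the choose-LoRA-or-plain-Linear pattern repeated above for wq, wk, wv, and wo could be factored into a hypothetical helper such as maybe_lora_linear below, assuming args carries the new r, lora_alpha, and target_modules fields from ModelArgs:

from torch import nn

from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs


def maybe_lora_linear(
    args: ModelArgs, module_name: str, in_dim: int, out_dim: int, bias: bool
) -> nn.Module:
    # Use LoRALinear when this projection name is listed in target_modules;
    # otherwise fall back to a plain nn.Linear, mirroring the diff above.
    if args.target_modules is not None and module_name in args.target_modules:
        return LoRALinear(
            in_dim=in_dim,
            out_dim=out_dim,
            rank=args.r,
            alpha=args.lora_alpha,
            dropout=0.0,
            use_bias=bias,
        )
    return nn.Linear(in_dim, out_dim, bias=bias)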
1 change: 1 addition & 0 deletions examples/models/llama/llama_transformer.py
@@ -18,6 +18,7 @@
ForwardOptions,
)

from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.rope import Rope
48 changes: 48 additions & 0 deletions examples/models/llama/lora.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


class LoRALinear(nn.Module):
"""LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`."""

def __init__(
self,
in_dim: int,
out_dim: int,
rank: int,
alpha: float,
dropout: float = 0.0,
use_bias: bool = False,
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.rank = rank
self.alpha = alpha
self.use_bias = use_bias
self.dropout = dropout

linear = nn.Linear(in_dim, out_dim, bias=use_bias)
weight = linear.weight
bias = linear.bias if self.use_bias else None
self.register_parameter("weight", nn.Parameter(weight))
self.register_parameter(
"bias", nn.Parameter(bias) if bias is not None else None
)

self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.nn.functional.linear(x, self.weight, self.bias)
lora_out = self.lora_a(self.dropout(x))
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)

return out + lora_out
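A minimal usage sketch (not part of the diff) of what LoRALinear computes: the base projection through the frozen weight plus a low-rank update scaled by alpha / rank. The dimensions below are illustrative.

import torch

from executorch.examples.models.llama.lora import LoRALinear

layer = LoRALinear(in_dim=2048, out_dim=2048, rank=8, alpha=16.0)
x = torch.randn(1, 4, 2048)  # (batch, seq_len, in_dim)
y = layer(x)  # linear(x, weight) + (alpha / rank) * lora_b(lora_a(x))
assert y.shape == (1, 4, 2048)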
10 changes: 10 additions & 0 deletions examples/models/llama/model_args.py
@@ -55,8 +55,18 @@ class ModelArgs:
eos_count: int = 2

quantization_args: Optional[dict] = None
# LoRA for QAT.
lora_args: Optional[dict] = None

# LoRA arguments to set up a LoRA inference model.
# These arguments come directly from a torchtune LoRA config.
r: Optional[int] = None  # LoRA rank.
lora_alpha: Optional[int] = None  # LoRA scaling alpha.
# E.g. ["q_proj", "k_proj", "v_proj", "output_proj"].
target_modules: Optional[list] = None
peft_type: Optional[str] = None # PEFT type.
base_model_name_or_path: Optional[str] = None # Base model name or path.

def __post_init__(self):
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
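For context, these fields mirror a torchtune / PEFT-style LoRA adapter config. A hedged example of constructing ModelArgs with them (the values are illustrative, and the remaining ModelArgs fields are assumed to keep their defaults):

from executorch.examples.models.llama.model_args import ModelArgs

args = ModelArgs(
    r=8,  # LoRA rank
    lora_alpha=16,  # LoRA scaling alpha
    target_modules=["q_proj", "v_proj"],  # projections that become LoRALinear
    peft_type="LORA",  # illustrative value
    base_model_name_or_path="meta-llama/Llama-3.2-1B",  # illustrative value
)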