1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
@@ -13,6 +13,7 @@ runtime.python_library(
name = "llama_transformer",
srcs = [
"llama_transformer.py",
"lora.py",
"rope.py",
"attention.py",
"model_args.py",
53 changes: 45 additions & 8 deletions examples/models/llama/attention.py
@@ -325,7 +325,28 @@ def update(

@register_attention("mha")
class AttentionMHA(Attention):
def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
def __init__(
self,
args: ModelArgs,
layer_id: int,
rope: Rope,
wq: Optional[nn.Module] = None,
wk: Optional[nn.Module] = None,
wv: Optional[nn.Module] = None,
wo: Optional[nn.Module] = None,
):
"""
Multi-head attention layer.

Args:
args (ModelArgs): Model configuration parameters.
layer_id (int): Layer index.
rope (Rope): Rotary position embedding module.
wq (Optional[nn.Module]): Query projection module. If None, use regular nn.Linear.
wk (Optional[nn.Module]): Key projection module. If None, use regular nn.Linear.
wv (Optional[nn.Module]): Value projection module. If None, use regular nn.Linear.
wo (Optional[nn.Module]): Output projection module. If None, use regular nn.Linear.
"""
super().__init__()
self.use_kv_cache = args.use_kv_cache
self.n_heads = args.n_heads
@@ -350,16 +371,32 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

self.wq = nn.Linear(
    self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
)
self.wq = (
    wq
    if wq is not None
    else nn.Linear(
        self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
    )
)
self.wk = nn.Linear(
    self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
self.wk = (
    wk
    if wk is not None
    else nn.Linear(
        self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
    )
)
self.wv = nn.Linear(
    self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
self.wv = (
    wv
    if wv is not None
    else nn.Linear(
        self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
    )
)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
self.wo = (
    wo
    if wo is not None
    else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
)

self.layer_id = layer_id

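For illustration only (not part of this diff): a minimal sketch of constructing AttentionMHA with an injected projection module. The ModelArgs field values, the Rope(args) construction, and the explicit head_dim are assumptions made for the example.

import torch.nn as nn

from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

# Illustrative config values; real models set many more fields.
# head_dim = dim // n_heads, set explicitly here to keep the sketch self-contained.
args = ModelArgs(dim=2048, n_heads=32, n_kv_heads=8, n_layers=1, head_dim=64)
rope = Rope(args)

# Any nn.Module with matching in/out features can be injected; a plain
# nn.Linear stands in here for something like LoRALinear.
custom_wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)

# wk, wv, wo are left as None, so AttentionMHA falls back to nn.Linear for them.
attn = AttentionMHA(args, layer_id=0, rope=rope, wq=custom_wq)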
63 changes: 62 additions & 1 deletion examples/models/llama/llama_transformer.py
@@ -18,6 +18,7 @@
ForwardOptions,
)

from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.rope import Rope
@@ -255,7 +256,67 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
layers = torch.nn.ModuleList()
cls = ATTENTION_REGISTRY[model_args.attention_type]
for layer_id in range(model_args.n_layers):
attention = cls(model_args, layer_id, rope)
wq = (
LoRALinear(
in_dim=model_args.dim,
out_dim=model_args.n_heads * model_args.head_dim,
rank=model_args.r,
alpha=model_args.lora_alpha,
dropout=0.0,
use_bias=model_args.attention_qkv_bias,
)
if model_args.target_modules is not None
and "q_proj" in model_args.target_modules
else None
)

wk = (
LoRALinear(
in_dim=model_args.dim,
out_dim=model_args.n_kv_heads * model_args.head_dim,
rank=model_args.r,
alpha=model_args.lora_alpha,
dropout=0.0,
use_bias=model_args.attention_qkv_bias,
)
if model_args.target_modules is not None
and "k_proj" in model_args.target_modules
else None
)

wv = (
LoRALinear(
in_dim=model_args.dim,
out_dim=model_args.n_kv_heads * model_args.head_dim,
rank=model_args.r,
alpha=model_args.lora_alpha,
dropout=0.0,
use_bias=model_args.attention_qkv_bias,
)
if model_args.target_modules is not None
and "v_proj" in model_args.target_modules
else None
)

wo = (
LoRALinear(
in_dim=model_args.n_heads * model_args.head_dim,
out_dim=model_args.dim,
rank=model_args.r,
alpha=model_args.lora_alpha,
dropout=0.0,
use_bias=model_args.attention_qkv_bias,
)
if model_args.target_modules is not None
and "output_proj" in model_args.target_modules
else None
)
if model_args.attention_type == "static":
# Static attention constructs ModuleLists for qkvo and
# populates them with MHA attention layers; do not pass in
# wq, wk, wv, wo.
attention = cls(model_args, layer_id, rope)
JacobSzwejbka (Contributor), Jul 18, 2025:

Why are you doing this in llama_transformer at all? You have model args passed to the constructor; couldn't you just construct these modules in the constructor rather than passing them to __init__?

Not having to add attention-specific logic to llama_transformer.py is the intention of the attention registry, so if we have to break that abstraction I think we need to elaborate why.

lucylq (Contributor Author), Jul 22, 2025:

The intention is to make building a transformer composable, e.g. you can pass in any type of attention, with any type of linears.

If we construct them inside the transformer constructor, then the transformer has to understand what all the different options are. That is, we didn't want to put LoRA logic inside MHA attention itself, as it isn't tied to MHA specifically.

Contributor:

Attention already has to understand what all the different options are. That's why model args is passed into it, no?

Contributor:

> Attention already has to understand what all the different options are. That's why model args is passed into it, no?

As in, it's already possible to generate invalid configs where someone sticks model options in and then uses an attention that doesn't support those options.

Contributor Author:

LoRA specifically doesn't have to be coupled with the attention implementation (e.g. MHA). I can see why it makes sense to do it inside if other optimizations are also applied within, e.g., MHA though...

Contributor:

@JacobSzwejbka my original concern on this is that the model config might be non-standardized. Now that we have Jack's new config system, maybe this is less of a concern.

Contributor:

This setup is already dependent on the model config though. The LoRA constructors use tons of args from the config.

else:
attention = cls(model_args, layer_id, rope, wq, wk, wv, wo)
transformer_block = TransformerBlock(model_args, attention)
layers.append(transformer_block)

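For illustration only (not part of this diff): a hypothetical sketch of driving the new LoRA path purely through ModelArgs; the field values below are placeholders rather than a real configuration.

from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs

# Placeholder values; a real checkpoint's params.json / adapter config would supply these.
model_args = ModelArgs(
    dim=2048,
    n_heads=32,
    n_kv_heads=8,
    head_dim=64,  # dim // n_heads, set explicitly for the sketch
    n_layers=16,
    vocab_size=32000,
    attention_type="mha",
    r=8,  # LoRA rank
    lora_alpha=16,  # LoRA scaling is lora_alpha / r
    target_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
)

transformer = construct_transformer(model_args)
# Projections named in target_modules are constructed as LoRALinear and passed
# into AttentionMHA; the rest fall back to plain nn.Linear.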
48 changes: 48 additions & 0 deletions examples/models/llama/lora.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


class LoRALinear(nn.Module):
"""LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`."""

def __init__(
self,
in_dim: int,
out_dim: int,
rank: int,
alpha: float,
dropout: float = 0.0,
use_bias: bool = False,
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.rank = rank
self.alpha = alpha
self.use_bias = use_bias
self.dropout = dropout

linear = nn.Linear(in_dim, out_dim, bias=use_bias)
weight = linear.weight
bias = linear.bias if self.use_bias else None
self.register_parameter("weight", nn.Parameter(weight))
self.register_parameter(
"bias", nn.Parameter(bias) if bias is not None else None
)

self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.nn.functional.linear(x, self.weight, self.bias)
lora_out = self.lora_a(self.dropout(x))
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)

return out + lora_out
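
A brief usage sketch (not part of this diff) of the layer above; it computes y = Wx + (alpha / rank) * B(A(x)), with shapes chosen only for illustration.

import torch

from executorch.examples.models.llama.lora import LoRALinear

layer = LoRALinear(in_dim=2048, out_dim=2048, rank=8, alpha=16.0)

x = torch.randn(2, 16, 2048)  # (batch, seq_len, in_dim)
y = layer(x)
print(y.shape)  # torch.Size([2, 16, 2048])

# The adapter branch adds only rank * (in_dim + out_dim) weights per projection,
# which is why LoRA checkpoints stay small relative to the base weights.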
10 changes: 10 additions & 0 deletions examples/models/llama/model_args.py
@@ -55,8 +55,18 @@ class ModelArgs:
eos_count: int = 2

quantization_args: Optional[dict] = None
# LoRA for QAT.
lora_args: Optional[dict] = None

# LoRA arguments to set up a LoRA inference model.
# These arguments come directly from a torchtune LoRA config.
r: Optional[int] = None # Rank.
lora_alpha: Optional[int] = None # Alpha.
# E.g. q_proj, k_proj, v_proj, output_proj
target_modules: Optional[list] = None
peft_type: Optional[str] = None # PEFT type.
base_model_name_or_path: Optional[str] = None # Base model name or path.

def __post_init__(self):
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
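
For context (not part of this diff): a hypothetical sketch of filling the new LoRA fields from a PEFT/torchtune-style adapter_config.json; the file name, keys, and base-model values are assumptions for the example.

import json

from executorch.examples.models.llama.model_args import ModelArgs

with open("adapter_config.json") as f:
    adapter_cfg = json.load(f)

model_args = ModelArgs(
    dim=2048,  # base-model fields would normally come from params.json
    n_heads=32,
    vocab_size=32000,
    r=adapter_cfg["r"],
    lora_alpha=adapter_cfg["lora_alpha"],
    target_modules=adapter_cfg["target_modules"],
    peft_type=adapter_cfg.get("peft_type"),
    base_model_name_or_path=adapter_cfg.get("base_model_name_or_path"),
)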