Commit 262668d

Add LoRA linear definition
Pull Request resolved: #11044

Add LoRA linear definition. Pull the linear layers out of attention and allow a custom linear (e.g. a LoRA linear) to be passed in. If none is provided, construct a plain linear (current behaviour).

ghstack-source-id: 298178479
exported-using-ghexport
Differential Revision: [D73953776](https://our.internmc.facebook.com/intern/diff/D73953776/)
1 parent b183830 commit 262668d
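
For context, the `LoRALinear` added in `examples/models/llama/lora.py` (diff below) computes the usual dense projection plus a scaled low-rank update:

$$y = W x \,(+\, b) \;+\; \frac{\alpha}{r}\, B\,A\,\mathrm{dropout}(x)$$

where $W$ (and the optional bias $b$) come from a regular `nn.Linear`, $A$ (`lora_a`) maps `in_dim -> rank`, $B$ (`lora_b`) maps `rank -> out_dim`, $r$ is the rank, and $\alpha$ is `lora_alpha`.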

File tree

5 files changed: 129 additions, 8 deletions


examples/models/llama/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ runtime.python_library(
     name = "llama_transformer",
     srcs = [
         "llama_transformer.py",
+        "lora.py",
         "rope.py",
         "attention.py",
         "model_args.py",

examples/models/llama/attention.py

Lines changed: 69 additions & 8 deletions

@@ -325,7 +325,20 @@ def update(
 
 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+    ):
+        """
+        Multi-head attention layer.
+
+        Args:
+            args (ModelArgs): Model configuration parameters.
+            layer_id (int): Layer index.
+            rope (Rope): Rotary position embedding module.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -350,16 +363,64 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
             self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
 
-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wq = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None
+            and "q_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wk = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wk = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None
+            and "k_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wv = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wv = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None
+            and "v_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
+        )
+        self.wo = (
+            LoRALinear(
+                in_dim=args.n_kv_heads * args.head_dim,
+                out_dim=args.dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None
+            and "output_proj" in args.target_modules
+            else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
 
         self.layer_id = layer_id
 
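The wq/wk/wv/wo construction above repeats the same gate four times: build a `LoRALinear` when the projection name appears in `args.target_modules`, otherwise fall back to a plain `nn.Linear` (the previous behaviour). A minimal standalone sketch of that pattern, assuming `LoRALinear` is importable as in `llama_transformer.py`; the helper name and the example dimensions are illustrative, not part of the commit:

```python
from typing import List, Optional

import torch.nn as nn

from executorch.examples.models.llama.lora import LoRALinear


def maybe_lora_linear(
    in_dim: int,
    out_dim: int,
    proj_name: str,
    target_modules: Optional[List[str]],
    rank: Optional[int],
    alpha: Optional[float],
    bias: bool = False,
) -> nn.Module:
    """Hypothetical helper mirroring the wq/wk/wv/wo gating in AttentionMHA."""
    if target_modules is not None and proj_name in target_modules:
        # Projection is targeted by the LoRA config: use a LoRALinear.
        return LoRALinear(
            in_dim=in_dim,
            out_dim=out_dim,
            rank=rank,
            alpha=alpha,
            dropout=0.0,
            use_bias=bias,
        )
    # Not targeted: construct a plain linear, as before this change.
    return nn.Linear(in_dim, out_dim, bias=bias)


# Only "q_proj" is targeted here, so wq is a LoRALinear and wk is an nn.Linear.
wq = maybe_lora_linear(2048, 2048, "q_proj", ["q_proj"], rank=8, alpha=16.0)
wk = maybe_lora_linear(2048, 512, "k_proj", ["q_proj"], rank=8, alpha=16.0)
```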

examples/models/llama/llama_transformer.py

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
     ForwardOptions,
 )
 
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope

examples/models/llama/lora.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+
+
+class LoRALinear(nn.Module):
+    """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`."""
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.rank = rank
+        self.alpha = alpha
+        self.use_bias = use_bias
+        self.dropout = dropout
+
+        linear = nn.Linear(in_dim, out_dim, bias=use_bias)
+        weight = linear.weight
+        bias = linear.bias if self.use_bias else None
+        self.register_parameter("weight", nn.Parameter(weight))
+        self.register_parameter(
+            "bias", nn.Parameter(bias) if bias is not None else None
+        )
+
+        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
+        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
+        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.nn.functional.linear(x, self.weight, self.bias)
+        lora_out = self.lora_a(self.dropout(x))
+        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+
+        return out + lora_out
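A quick usage sketch of the class above, assuming the module path from this commit; the dimensions are illustrative. The output is the dense projection plus the `(alpha / rank)`-scaled `lora_b(lora_a(dropout(x)))` path.

```python
import torch

from executorch.examples.models.llama.lora import LoRALinear

# Rank-8 adapter on a 2048 -> 2048 projection; alpha / rank scales the update.
layer = LoRALinear(in_dim=2048, out_dim=2048, rank=8, alpha=16.0)

x = torch.randn(1, 16, 2048)  # (batch, seq_len, dim)
y = layer(x)                  # dense(x) + (alpha / rank) * lora_b(lora_a(x))
print(y.shape)                # torch.Size([1, 16, 2048])
```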

examples/models/llama/model_args.py

Lines changed: 10 additions & 0 deletions

@@ -55,8 +55,18 @@ class ModelArgs:
     eos_count: int = 2
 
     quantization_args: Optional[dict] = None
+    # LoRA for QAT.
     lora_args: Optional[dict] = None
 
+    # LoRA arguments to set up a LoRA inference model.
+    # These arguments come directly from a torchtune LoRA config.
+    r: Optional[int] = None  # Rank.
+    lora_alpha: Optional[int] = None  # Alpha.
+    # Eg. q_proj, k_proj, v_proj, output_proj
+    target_modules: Optional[list] = None
+    peft_type: Optional[str] = None  # PEFT type.
+    base_model_name_or_path: Optional[str] = None  # Base model name or path.
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
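A hedged example of how the new fields might be populated from a torchtune/PEFT-style adapter config. The config values, the file layout, and the assumption that the remaining `ModelArgs` fields keep their defaults are all illustrative, not part of the commit:

```python
import json

from executorch.examples.models.llama.model_args import ModelArgs

# Illustrative adapter config in the torchtune/PEFT style referenced above.
adapter_config = json.loads("""
{
  "r": 8,
  "lora_alpha": 16,
  "target_modules": ["q_proj", "v_proj", "output_proj"],
  "peft_type": "LORA",
  "base_model_name_or_path": "path/to/base-model"
}
""")

args = ModelArgs(
    r=adapter_config["r"],
    lora_alpha=adapter_config["lora_alpha"],
    target_modules=adapter_config["target_modules"],
    peft_type=adapter_config["peft_type"],
    base_model_name_or_path=adapter_config["base_model_name_or_path"],
)

# With this config, AttentionMHA builds LoRALinear for q_proj, v_proj and
# output_proj, and a plain nn.Linear for k_proj.
```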
