
Commit 6d4b68a

Authored by pytorchbot and lucylq
Add LoRA linear definition (#12861)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #11044 by @lucylq (please use this as the source of truth for the PR details, comments, and reviews)
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/82/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/82/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/82/orig

@diff-train-skip-merge

Co-authored-by: lucylq <[email protected]>
1 parent: b8fe100

4 files changed: +125 additions, -8 deletions

examples/models/llama/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ runtime.python_library(
     name = "llama_transformer",
     srcs = [
         "llama_transformer.py",
+        "lora.py",
         "rope.py",
         "attention.py",
         "model_args.py",

examples/models/llama/attention.py

Lines changed: 66 additions & 8 deletions

@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -325,7 +326,20 @@ def update(
 
 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+    ):
+        """
+        Multi-head attention layer.
+
+        Args:
+            args (ModelArgs): Model configuration parameters.
+            layer_id (int): Layer index.
+            rope (Rope): Rotary position embedding module.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -350,16 +364,60 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
         self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
 
-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wq = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "q_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wk = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wk = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "k_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wv = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wv = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "v_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
+        )
+        self.wo = (
+            LoRALinear(
+                in_dim=args.n_kv_heads * args.head_dim,
+                out_dim=args.dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "output_proj" in args.target_modules
+            else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
 
         self.layer_id = layer_id
 
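The construction above keeps each projection as a plain `nn.Linear` unless its torchtune-style name (`q_proj`, `k_proj`, `v_proj`, `output_proj`) appears in `args.target_modules`, in which case a `LoRALinear` is built from `args.r` and `args.lora_alpha`. A small standalone sketch of that selection logic, using made-up dimensions and a hypothetical `target_modules` list rather than a real `ModelArgs`:

```python
# Standalone sketch of the per-projection LoRA/Linear selection used above.
# Dimensions, rank, alpha, and target_modules are made up for illustration.
import torch.nn as nn

from executorch.examples.models.llama.lora import LoRALinear

dim, n_heads, head_dim = 2048, 16, 128
target_modules = ["q_proj", "v_proj"]  # hypothetical torchtune-style list

wq = (
    LoRALinear(in_dim=dim, out_dim=n_heads * head_dim, rank=8, alpha=16.0)
    if target_modules is not None and "q_proj" in target_modules
    else nn.Linear(dim, n_heads * head_dim, bias=False)
)
wk = (
    LoRALinear(in_dim=dim, out_dim=n_heads * head_dim, rank=8, alpha=16.0)
    if target_modules is not None and "k_proj" in target_modules
    else nn.Linear(dim, n_heads * head_dim, bias=False)
)

print(type(wq).__name__)  # LoRALinear -- "q_proj" is in target_modules
print(type(wk).__name__)  # Linear     -- "k_proj" is not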
examples/models/llama/lora.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+
+
+class LoRALinear(nn.Module):
+    """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`."""
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.rank = rank
+        self.alpha = alpha
+        self.use_bias = use_bias
+        self.dropout = dropout
+
+        linear = nn.Linear(in_dim, out_dim, bias=use_bias)
+        weight = linear.weight
+        bias = linear.bias if self.use_bias else None
+        self.register_parameter("weight", nn.Parameter(weight))
+        self.register_parameter(
+            "bias", nn.Parameter(bias) if bias is not None else None
+        )
+
+        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
+        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
+        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.nn.functional.linear(x, self.weight, self.bias)
+        lora_out = self.lora_a(self.dropout(x))
+        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+
+        return out + lora_out
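
`forward` returns the base projection `linear(x, weight, bias)` plus the low-rank path `lora_b(lora_a(x))` scaled by `alpha / rank`. A minimal usage sketch with made-up shapes and hyperparameters:

```python
# Minimal usage sketch for LoRALinear; shapes, rank, and alpha are
# arbitrary example values, not values taken from the PR.
import torch

from executorch.examples.models.llama.lora import LoRALinear

layer = LoRALinear(in_dim=512, out_dim=256, rank=8, alpha=16.0)

x = torch.randn(2, 10, 512)  # (batch, seq_len, in_dim)
y = layer(x)
print(y.shape)  # torch.Size([2, 10, 256])

# The output decomposes into the base projection plus the scaled
# low-rank update, matching forward() above.
base = torch.nn.functional.linear(x, layer.weight, layer.bias)
lora = (layer.alpha / layer.rank) * layer.lora_b(layer.lora_a(x))
print(torch.allclose(y, base + lora))  # True
```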

examples/models/llama/model_args.py

Lines changed: 10 additions & 0 deletions

@@ -55,8 +55,18 @@ class ModelArgs:
     eos_count: int = 2
 
     quantization_args: Optional[dict] = None
+    # LoRA for QAT.
     lora_args: Optional[dict] = None
 
+    # LoRA arguments to set up a LoRA inference model.
+    # These arguments come directly from a torchtune LoRA config.
+    r: Optional[int] = None  # Rank.
+    lora_alpha: Optional[int] = None  # Alpha.
+    # Eg. q_proj, k_proj, v_proj, output_proj
+    target_modules: Optional[list] = None
+    peft_type: Optional[str] = None  # PEFT type.
+    base_model_name_or_path: Optional[str] = None  # Base model name or path.
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
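
The new fields use the same names as a torchtune/PEFT LoRA adapter config, so they can be filled in directly from such a config. A sketch under that assumption; the `adapter_config.json` filename and its keys are illustrative, not something this PR reads itself:

```python
# Hedged sketch: populate the new ModelArgs LoRA fields from a
# PEFT/torchtune-style adapter config. The filename and keys below are
# assumptions for illustration; this PR only defines the fields.
import json

from executorch.examples.models.llama.model_args import ModelArgs

with open("adapter_config.json") as f:
    cfg = json.load(f)

args = ModelArgs(
    r=cfg.get("r"),
    lora_alpha=cfg.get("lora_alpha"),
    target_modules=cfg.get("target_modules"),  # e.g. ["q_proj", "v_proj"]
    peft_type=cfg.get("peft_type"),
    base_model_name_or_path=cfg.get("base_model_name_or_path"),
)
```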
