-
Notifications
You must be signed in to change notification settings - Fork 689
Add LoRA linear definition #11044
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add LoRA linear definition #11044
Changes from 2 commits
287e1dd
0651a0e
0edac47
8a2f4cf
5bc1e69
db09a39
78c5d29
e2dcd8e
5f9ae15
f40ca33
8bcd073
180fe7f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
ForwardOptions, | ||
) | ||
|
||
from executorch.examples.models.llama.lora import LoRALinear | ||
from executorch.examples.models.llama.model_args import ModelArgs | ||
from executorch.examples.models.llama.norm import RMSNorm | ||
from executorch.examples.models.llama.rope import Rope | ||
|
@@ -255,7 +256,67 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: | |
layers = torch.nn.ModuleList() | ||
cls = ATTENTION_REGISTRY[model_args.attention_type] | ||
for layer_id in range(model_args.n_layers): | ||
attention = cls(model_args, layer_id, rope) | ||
wq = ( | ||
LoRALinear( | ||
in_dim=model_args.dim, | ||
out_dim=model_args.n_heads * model_args.head_dim, | ||
rank=model_args.r, | ||
alpha=model_args.lora_alpha, | ||
dropout=0.0, | ||
use_bias=model_args.attention_qkv_bias, | ||
) | ||
if model_args.target_modules is not None | ||
and "q_proj" in model_args.target_modules | ||
else None | ||
) | ||
|
||
wk = ( | ||
LoRALinear( | ||
in_dim=model_args.dim, | ||
out_dim=model_args.n_kv_heads * model_args.head_dim, | ||
rank=model_args.r, | ||
alpha=model_args.lora_alpha, | ||
dropout=0.0, | ||
use_bias=model_args.attention_qkv_bias, | ||
) | ||
if model_args.target_modules is not None | ||
and "k_proj" in model_args.target_modules | ||
else None | ||
) | ||
|
||
wv = ( | ||
LoRALinear( | ||
in_dim=model_args.dim, | ||
out_dim=model_args.n_kv_heads * model_args.head_dim, | ||
rank=model_args.r, | ||
alpha=model_args.lora_alpha, | ||
dropout=0.0, | ||
use_bias=model_args.attention_qkv_bias, | ||
) | ||
if model_args.target_modules is not None | ||
else None | ||
) | ||
|
||
wo = ( | ||
LoRALinear( | ||
in_dim=model_args.n_kv_heads * model_args.head_dim, | ||
out_dim=model_args.dim, | ||
rank=model_args.r, | ||
alpha=model_args.lora_alpha, | ||
dropout=0.0, | ||
use_bias=model_args.attention_qkv_bias, | ||
) | ||
if model_args.target_modules is not None | ||
and "output_proj" in model_args.target_modules | ||
else None | ||
) | ||
if model_args.attention_type == "static": | ||
# Static attention constructs ModuleLists for qkvo and | ||
# populates them with MHA attention layers; do not pass in | ||
# wq, wk, wv, wo. | ||
attention = cls(model_args, layer_id, rope) | ||
|
||
else: | ||
attention = cls(model_args, layer_id, rope, wq, wk, wv, wo) | ||
transformer_block = TransformerBlock(model_args, attention) | ||
layers.append(transformer_block) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from torch import nn | ||
|
||
|
||
class LoRALinear(nn.Module): | ||
"""LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`.""" | ||
|
||
def __init__( | ||
self, | ||
in_dim: int, | ||
out_dim: int, | ||
rank: int, | ||
alpha: float, | ||
dropout: float = 0.0, | ||
use_bias: bool = False, | ||
): | ||
super().__init__() | ||
self.in_dim = in_dim | ||
self.out_dim = out_dim | ||
self.rank = rank | ||
self.alpha = alpha | ||
self.use_bias = use_bias | ||
self.dropout = dropout | ||
|
||
linear = nn.Linear(in_dim, out_dim, bias=use_bias) | ||
weight = linear.weight | ||
bias = linear.bias if self.use_bias else None | ||
self.register_parameter("weight", nn.Parameter(weight)) | ||
self.register_parameter( | ||
"bias", nn.Parameter(bias) if bias is not None else None | ||
) | ||
|
||
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() | ||
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) | ||
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
out = torch.nn.functional.linear(x, self.weight, self.bias) | ||
lora_out = self.lora_a(self.dropout(x)) | ||
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) | ||
|
||
return out + lora_out |
Uh oh!
There was an error while loading. Please reload this page.