
Commit c755ddb

Add lora linear
1 parent 999eb7e commit c755ddb

File tree: 5 files changed, +204 -27 lines

  examples/models/llama/attention.py
  examples/models/llama/export_llama_lib.py
  examples/models/llama/lora.py
  examples/models/llama/model.py
  examples/models/llama/model_args.py


examples/models/llama/attention.py

Lines changed: 8 additions & 10 deletions
@@ -167,6 +167,10 @@ def __init__(
         args: ModelArgs,
         layer_id: int,
         rope: Rope,
+        wq: nn.Module,
+        wk: nn.Module,
+        wv: nn.Module,
+        wo: nn.Module,
     ):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
@@ -190,16 +194,10 @@ def __init__(
         self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
         self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
-        )
-        self.wk = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
-        )
-        self.wv = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
-        )
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
+        self.wq = wq
+        self.wk = wk
+        self.wv = wv
+        self.wo = wo

         self.layer_id = layer_id
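
With this change the attention module no longer builds its own projection layers; the caller passes wq, wk, wv, and wo in as nn.Module instances, so the same attention class can be wired with plain nn.Linear projections or LoRALinear ones. A minimal sketch of the new calling convention, with illustrative sizes and a hypothetical attention class name (not taken from this commit):

import torch.nn as nn

# Hypothetical sizes; real values come from ModelArgs.
dim, n_heads, n_kv_heads, head_dim = 2048, 32, 8, 64

# Baseline behavior: plain linear projections, matching what the attention
# class used to construct internally before this commit.
wq = nn.Linear(dim, n_heads * head_dim, bias=False)
wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
wo = nn.Linear(n_heads * head_dim, dim, bias=False)

# attention = SomeAttention(args, layer_id, rope, wq, wk, wv, wo)  # illustrative call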

examples/models/llama/export_llama_lib.py

Lines changed: 30 additions & 0 deletions
@@ -209,6 +209,18 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )

+    parser.add_argument(
+        "--adapter_checkpoint",
+        required=False,
+        help="Path to the adapter.pt file. Used if the model has trained LoRA adapters. Must provide adapter_config.",
+    )
+
+    parser.add_argument(
+        "--adapter_config",
+        required=False,
+        help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.",
+    )
+
     parser.add_argument(
         "--use_qnn_sha",
         action="store_true",
@@ -592,6 +604,18 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
     params_path = canonical_path(args.params) if args.params else None
+
+    assert (args.adapter_checkpoint is None and args.adapter_config is None) or (
+        args.adapter_checkpoint is not None and args.adapter_config is not None
+    ), "Must provide both adapter_checkpoint and adapter_config, or neither"
+
+    adapter_checkpoint_path = (
+        canonical_path(args.adapter_checkpoint) if args.adapter_checkpoint else None
+    )
+    adapter_config_path = (
+        canonical_path(args.adapter_config) if args.adapter_config else None
+    )
+
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

@@ -603,6 +627,8 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
+        adapter_checkpoint=adapter_checkpoint_path,
+        adapter_config=adapter_config_path,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
         generate_full_logits=args.generate_full_logits,
@@ -1040,6 +1066,8 @@ def _load_llama_model(
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: Optional[str] = None,
+    adapter_checkpoint: Optional[str] = None,
+    adapter_config: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     generate_full_logits: bool = False,
@@ -1087,6 +1115,8 @@ def _load_llama_model(
         checkpoint=checkpoint,
         checkpoint_dir=checkpoint_dir,
         params=params_path,
+        adapter_checkpoint=adapter_checkpoint,
+        adapter_config=adapter_config,
         use_kv_cache=use_kv_cache,
         use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
         generate_full_logits=generate_full_logits,
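
The two new flags are meant to be passed together or not at all, as enforced by the assert added to _prepare_for_llama_export. A standalone sketch of an equivalent formulation of that pairing rule, using a toy parser rather than the real build_args_parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adapter_checkpoint", required=False)
parser.add_argument("--adapter_config", required=False)

# Both flags provided: passes the check.
args = parser.parse_args(
    ["--adapter_checkpoint", "adapter.pt", "--adapter_config", "adapter_config.json"]
)
assert (args.adapter_checkpoint is None) == (
    args.adapter_config is None
), "Must provide both adapter_checkpoint and adapter_config, or neither"

# Only one flag provided: the same check would raise AssertionError.
args = parser.parse_args(["--adapter_checkpoint", "adapter.pt"])
# assert (args.adapter_checkpoint is None) == (args.adapter_config is None)  # fails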

examples/models/llama/lora.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+# Helper functions for transforming the model to be able to load checkpoints with
+# LoRA adapters. See https://arxiv.org/abs/2106.09685 for more details about LoRA.
+
+import torch
+from torch import nn
+
+
+class LoRALinear(nn.Module):
+    """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`."""
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.rank = rank
+        self.alpha = alpha
+        self.use_bias = use_bias
+        self.dropout = dropout
+
+        linear = nn.Linear(in_dim, out_dim, bias=use_bias)
+        weight = linear.weight
+        bias = linear.bias if self.use_bias else None
+        self.register_parameter("weight", nn.Parameter(weight))
+        self.register_parameter(
+            "bias", nn.Parameter(bias) if bias is not None else None
+        )
+
+        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
+        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
+        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.nn.functional.linear(x, self.weight, self.bias)
+        lora_out = self.lora_a(self.dropout(x))
+        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+
+        return out + lora_out
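
LoRALinear keeps a full-rank base weight and adds a low-rank residual path: out = x @ weight.T (+ bias) plus (alpha / rank) * lora_b(lora_a(dropout(x))), where lora_a maps in_dim -> rank and lora_b maps rank -> out_dim. A small usage sketch with illustrative shapes and hyperparameters:

import torch
from executorch.examples.models.llama.lora import LoRALinear

layer = LoRALinear(in_dim=16, out_dim=32, rank=4, alpha=8.0)
x = torch.randn(2, 16)

y = layer(x)  # base projection plus scaled low-rank update
print(y.shape)  # torch.Size([2, 32])

# Equivalent decomposition of the forward pass (dropout=0.0 -> Identity).
base = torch.nn.functional.linear(x, layer.weight, layer.bias)
lora = (layer.alpha / layer.rank) * layer.lora_b(layer.lora_a(x))
assert torch.allclose(y, base + lora)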

examples/models/llama/model.py

Lines changed: 105 additions & 17 deletions
@@ -20,9 +20,12 @@
     Transformer,
     TransformerBlock,
 )
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope

+from torchtune.models import convert_weights
+
 try:
     from .fairseq2 import convert_to_llama_checkpoint

@@ -37,6 +40,86 @@ def convert_to_llama_checkpoint(**kwargs):
 from ..model_base import EagerModelBase


+def construct_llm(model_args: ModelArgs) -> Transformer:
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+
+    rope = Rope(model_args)
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+
+    for layer_id in range(model_args.n_layers):
+        wq = (
+            LoRALinear(
+                in_dim=model_args.dim,
+                out_dim=model_args.n_heads * model_args.head_dim,
+                rank=model_args.r,  # todo
+                alpha=model_args.lora_alpha,  # todo
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if "q_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.dim,
+                    model_args.n_heads * model_args.head_dim,
+                    bias=model_args.attention_qkv_bias,
+                )
+            )
+        )
+
+        wk = (
+            LoRALinear(
+                in_dim=model_args.dim,
+                out_dim=model_args.n_kv_heads * model_args.head_dim,
+                rank=model_args.r,  # todo
+                alpha=model_args.lora_alpha,  # todo
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if "k_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.dim,
+                    model_args.n_kv_heads * model_args.head_dim,
+                    bias=model_args.attention_qkv_bias,
+                )
+            )
+        )
+        wv = (
+            LoRALinear(
+                in_dim=model_args.dim,
+                out_dim=model_args.n_kv_heads * model_args.head_dim,
+                rank=model_args.r,  # todo
+                alpha=model_args.lora_alpha,  # todo
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if "v_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.dim,
+                    model_args.n_kv_heads * model_args.head_dim,
+                    bias=model_args.attention_qkv_bias,
+                )
+            )
+        )
+
+        # todo
+        wo = torch.nn.Linear(
+            model_args.n_heads * model_args.head_dim, model_args.dim, bias=False
+        )
+        attention = cls(model_args, layer_id, rope, wq, wk, wv, wo)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    # Construct transformer model.
+    return Transformer(model_args, layers, rope)
+
+
 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
         resource_dir = get_default_model_resource_dir(__file__)
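
construct_llm chooses, per projection, between a LoRALinear and a plain torch.nn.Linear based on whether the matching module name ("q_proj", "k_proj", "v_proj") appears in model_args.target_modules. A condensed sketch of that selection rule; the make_proj helper is hypothetical, not part of the commit, and adds a small None guard for the no-adapter case:

import torch.nn as nn
from executorch.examples.models.llama.lora import LoRALinear

def make_proj(name, in_dim, out_dim, args):
    # Hypothetical helper mirroring the inline conditional in construct_llm.
    if args.target_modules and name in args.target_modules:
        return LoRALinear(
            in_dim=in_dim,
            out_dim=out_dim,
            rank=args.r,
            alpha=args.lora_alpha,
            dropout=0.0,
            use_bias=args.attention_qkv_bias,
        )
    return nn.Linear(in_dim, out_dim, bias=args.attention_qkv_bias)

# e.g. wq = make_proj("q_proj", args.dim, args.n_heads * args.head_dim, args)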
@@ -49,6 +132,10 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

+        # Adapter
+        adapter_checkpoint = kwargs.get("adapter_checkpoint", None)
+        adapter_config = kwargs.get("adapter_config", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -132,6 +219,22 @@ def __init__(self, **kwargs):
         with open(params_path, "r") as f:
             params = json.loads(f.read())

+        # Get adapter checkpoint and config.
+        adapter_checkpoint = {}
+        adapter_config = {}
+        adapter_checkpoint_path = kwargs.get("adapter_checkpoint", None)
+        if adapter_checkpoint_path:
+            adapter_checkpoint = torch.load(
+                adapter_checkpoint_path, map_location=device, mmap=True
+            )
+            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+
+            adapter_config = kwargs.get("adapter_config", None)
+            with open(adapter_config, "r") as f:
+                adapter_config = json.loads(f.read())
+
+        checkpoint.update(adapter_checkpoint)
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
@@ -156,6 +259,7 @@ def __init__(self, **kwargs):
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
+            **adapter_config,
         )

         if model_args.use_scaled_rope:
@@ -177,23 +281,7 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-
-            # Construct attention layers.
-            rope = Rope(model_args)
-            if model_args.attention_type not in ATTENTION_REGISTRY:
-                raise ValueError(
-                    f"Unknown attention type: {model_args.attention_type}. "
-                    f"Available: {list(ATTENTION_REGISTRY.keys())}"
-                )
-            layers = torch.nn.ModuleList()
-            cls = ATTENTION_REGISTRY[model_args.attention_type]
-            for layer_id in range(model_args.n_layers):
-                attention = cls(model_args, layer_id, rope)
-                transformer_block = TransformerBlock(model_args, attention)
-                layers.append(transformer_block)
-
-            # Construct transformer model.
-            self.model_ = Transformer(model_args, layers, rope)
+            self.model_ = construct_llm(model_args)

         # Get checkpoint dtype.
         if checkpoint:
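
When an adapter checkpoint is provided, its torchtune-style keys are converted to Meta naming with convert_weights.tune_to_meta and then merged into the base state dict via checkpoint.update, so the LoRA tensors land next to the base weights before loading. A toy illustration of the merge step; the key names below are placeholders, not the real checkpoint layout:

# Stand-in state dicts; the real ones hold tensors keyed by module path.
checkpoint = {
    "layers.0.attention.wq.weight": "base_wq",
    "tok_embeddings.weight": "emb",
}
adapter_checkpoint = {
    "layers.0.attention.wq.lora_a.weight": "lora_a",
    "layers.0.attention.wq.lora_b.weight": "lora_b",
}

checkpoint.update(adapter_checkpoint)  # same merge as in Llama2Model.__init__
print(sorted(checkpoint))  # base keys plus the new LoRA keys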

examples/models/llama/model_args.py

Lines changed: 8 additions & 0 deletions
@@ -53,8 +53,16 @@ class ModelArgs:
     eos_count: int = 2

     quantization_args: Optional[dict] = None
+    # LoRA for QAT.
     lora_args: Optional[dict] = None

+    # LoRA arguments.
+    r: Optional[int] = None  # Rank.
+    lora_alpha: Optional[int] = None  # Alpha.
+    target_modules: Optional[list] = None  # Target modules.
+    peft_type: Optional[str] = None  # PEFT type.
+    base_model_name_or_path: Optional[str] = None  # Base model name or path.
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
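
The new fields mirror the keys of a PEFT-style adapter_config.json (r, lora_alpha, target_modules, peft_type, base_model_name_or_path), which is why the loaded config can be spread directly into ModelArgs with **adapter_config. An illustrative sketch; the config values and model dimensions below are examples, not taken from a real adapter:

from executorch.examples.models.llama.model_args import ModelArgs

# Example adapter_config.json contents, expressed as a Python dict.
adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "peft_type": "LORA",
    "base_model_name_or_path": "meta-llama/Llama-3.2-1B",  # example value
}

model_args = ModelArgs(dim=2048, n_layers=16, n_heads=32, **adapter_config)
print(model_args.r, model_args.target_modules)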
