
Commit e4d98b4

add lora for mlp and unsloth
1 parent 6e0c9f6 commit e4d98b4

7 files changed: +156 -10 lines changed


examples/models/llama/attention.py

Lines changed: 1 addition & 1 deletion
@@ -416,7 +416,7 @@ def __init__(
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None and "output_proj" in args.target_modules
+            if args.target_modules is not None and ("output_proj" in args.target_modules or "o_proj" in args.target_modules)
             else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )

examples/models/llama/convert_weights.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+from typing import Dict
+
+import torch
+
+from safetensors.torch import load_file
+from torchtune.models.convert_weights import get_mapped_key
+
+_UNSLOTH_TO_META = {
+    "base_model.model.model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.down_proj.lora_B.weight": "layers.{}.feed_forward.w2.lora_b.weight",
+    "base_model.model.model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1.lora_b.weight",
+    "base_model.model.model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv.lora_b.weight",
+}
+
+
+def unsloth_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from unsloth format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in unsloth format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+
+    for key, value in state_dict.items():
+        try:
+            new_key = get_mapped_key(key, _UNSLOTH_TO_META)
+        except Exception as e:
+            raise ValueError(f"Key {key} not found in mapping") from e
+
+        converted_state_dict[new_key] = value
+    return converted_state_dict
+
+
+def load_and_convert_unsloth_to_meta(checkpoint_path: str) -> Dict[str, torch.Tensor]:
+    """
+    Load a checkpoint file and convert it to Meta's format.
+
+    Args:
+        checkpoint_path (str): Path to the checkpoint file.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    state_dict = load_file(checkpoint_path)
+    return unsloth_to_meta(state_dict)
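For orientation, a minimal usage sketch of the loader above. The adapter filename is hypothetical; torchtune's get_mapped_key looks up each key by abstracting the layer index and fills it back into the {} placeholder of the mapped name.

# Hypothetical example: convert an unsloth LoRA adapter saved as safetensors
# into Meta-style keys before merging it into the base checkpoint.
from executorch.examples.models.llama.convert_weights import (
    load_and_convert_unsloth_to_meta,
)

adapter_sd = load_and_convert_unsloth_to_meta("adapter_model.safetensors")  # path is illustrative
# e.g. "base_model.model.model.layers.0.mlp.up_proj.lora_A.weight"
# comes out as "layers.0.feed_forward.w3.lora_a.weight"
for name, tensor in adapter_sd.items():
    print(name, tuple(tensor.shape))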

examples/models/llama/feed_forward.py

Lines changed: 53 additions & 0 deletions
@@ -1,4 +1,7 @@
 import torch.nn.functional as F
+
+from executorch.examples.models.llama.lora import LoRALinear
+from executorch.examples.models.llama.model_args import ModelArgs
 from torch import nn
 
 
@@ -11,3 +14,53 @@ def __init__(self, dim: int, hidden_dim: int):
 
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class LoRAFeedForward(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, args: ModelArgs):
+        super().__init__()
+
+        if args.r is None or args.lora_alpha is None:
+            raise ValueError("LoRA rank and alpha must be specified for LoRAFeedForward.")
+
+        self.w1 = (
+            LoRALinear(
+                in_dim=dim,
+                out_dim=hidden_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "gate_proj" in args.target_modules
+            else nn.Linear(dim, hidden_dim, bias=False)
+        )
+
+        self.w2 = (
+            LoRALinear(
+                in_dim=hidden_dim,
+                out_dim=dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "down_proj" in args.target_modules
+            else nn.Linear(hidden_dim, dim, bias=False)
+        )
+
+        self.w3 = (
+            LoRALinear(
+                in_dim=dim,
+                out_dim=hidden_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "up_proj" in args.target_modules
+            else nn.Linear(dim, hidden_dim, bias=False)
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
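A quick sketch of how this conditional wiring behaves. The rank, alpha, dimensions, and target_modules values below are illustrative, and it assumes ModelArgs is a dataclass whose other fields all have defaults.

import torch

from executorch.examples.models.llama.feed_forward import LoRAFeedForward
from executorch.examples.models.llama.model_args import ModelArgs

# Hypothetical config: only up_proj is adapted, so w3 becomes a LoRALinear
# while w1 and w2 stay plain nn.Linear modules.
args = ModelArgs(r=8, lora_alpha=16, target_modules=["up_proj"])
ffn = LoRAFeedForward(dim=64, hidden_dim=256, args=args)

print(type(ffn.w1).__name__, type(ffn.w2).__name__, type(ffn.w3).__name__)
print(ffn(torch.randn(1, 64)).shape)  # torch.Size([1, 64])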

examples/models/llama/install_requirements.sh

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,8 @@
 # Install tokenizers for hf .json tokenizer.
 # Install snakeviz for cProfile flamegraph
 # Install lm-eval for Model Evaluation with lm-evalution-harness.
-pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
+# Install safetensors to load safetensors checkpoints (currently adapter only).
+pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile safetensors
 
 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py

examples/models/llama/llama_transformer.py

Lines changed: 7 additions & 1 deletion
@@ -18,7 +18,7 @@
     AttentionSkip,
     ForwardOptions,
 )
-from executorch.examples.models.llama.feed_forward import FeedForward
+from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -93,6 +93,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
         ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
+        elif args.target_modules is not None and (
+            "down_proj" in args.target_modules
+            or "up_proj" in args.target_modules
+            or "gate_proj" in args.target_modules
+        ):
+            self.feed_forward = LoRAFeedForward(args.dim, args.hidden_dim, args)
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)

examples/models/llama/model.py

Lines changed: 30 additions & 6 deletions
@@ -15,7 +15,9 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
+
 from executorch.examples.models.llama.llama_transformer import construct_transformer
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 
@@ -140,14 +142,36 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         adapter_checkpoint = {}
         adapter_config = {}
         if adapter_checkpoint_path:
-            adapter_checkpoint = torch.load(
-                adapter_checkpoint_path, map_location=device, mmap=True
-            )
-            from torchtune.models import convert_weights
+            if adapter_checkpoint_path.endswith(".pt"):
+                adapter_checkpoint = torch.load(
+                    adapter_checkpoint_path, map_location=device, mmap=True
+                )
+                from torchtune.models import convert_weights
+
+                adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+            elif adapter_checkpoint_path.endswith(".safetensors"):
+                from executorch.examples.models.llama.convert_weights import load_and_convert_unsloth_to_meta
+
+                adapter_checkpoint = load_and_convert_unsloth_to_meta(adapter_checkpoint_path)
+            else:
+                raise ValueError(
+                    f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
+                )
 
-            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
             with open(adapter_config_path, "r") as f:
-                adapter_config = json.loads(f.read())
+                adapter_config_full = json.loads(f.read())
+            if (
+                "r" not in adapter_config_full
+                or "lora_alpha" not in adapter_config_full
+                or "target_modules" not in adapter_config_full
+            ):
+                raise ValueError(
+                    "Adapter config must contain r, lora_alpha, and target_modules.")
+            adapter_config = {
+                "r": adapter_config_full["r"],
+                "lora_alpha": adapter_config_full["lora_alpha"],
+                "target_modules": adapter_config_full["target_modules"],
+            }
             checkpoint.update(adapter_checkpoint)
 
         output_prune_map = None
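For reference, a sketch of the adapter_config.json shape this loader expects and the subset of fields it keeps. The field values and base model path below are placeholders, not taken from the commit.

import json

# Hypothetical adapter_config.json contents (values are illustrative).
raw = """
{
  "r": 8,
  "lora_alpha": 16,
  "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
  "peft_type": "LORA",
  "base_model_name_or_path": "<base-model-name-or-path>"
}
"""
adapter_config_full = json.loads(raw)

# Only the three fields consumed by model.py are carried forward.
adapter_config = {
    "r": adapter_config_full["r"],
    "lora_alpha": adapter_config_full["lora_alpha"],
    "target_modules": adapter_config_full["target_modules"],
}
print(adapter_config)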

examples/models/llama/model_args.py

Lines changed: 2 additions & 1 deletion
@@ -106,7 +106,8 @@ class ModelArgs:
     # These arguments come directly from a torchtune adapter_config.json file.
     r: Optional[int] = None  # Rank.
     lora_alpha: Optional[int] = None  # Alpha.
-    # Eg. q_proj, k_proj, v_proj, output_proj
+    # Modules that we can apply lora adapters to.
+    # Eg. q_proj, k_proj, v_proj, output_proj/o_proj, down_proj, gate_proj, up_proj
     target_modules: Optional[list] = None
     peft_type: Optional[str] = None  # PEFT type.
     base_model_name_or_path: Optional[str] = None  # Base model name or path.
