
Commit 7467108

Lora extraction (#240)
This PR adds support for extracting PEFT-compatible LoRAs from fine-tuned models, as demonstrated in my [LoRD](https://github.com/thomasgauthier/LoRD/) repo.
1 parent 52713a1 commit 7467108

File tree

- README.md
- mergekit/card.py
- mergekit/scripts/extract_lora.py
- pyproject.toml

4 files changed: +358 -0 lines changed

README.md

Lines changed: 10 additions & 0 deletions
@@ -175,6 +175,16 @@ Parameters:
- `filter_wise`: if true, weight calculation will be per-row rather than per-tensor. Not recommended.

## LoRA extraction

Mergekit allows extracting PEFT-compatible low-rank approximations of finetuned models.

### Usage

```sh
mergekit-extract-lora finetuned_model_id_or_path base_model_id_or_path output_path [--no-lazy-unpickle] --rank=desired_rank
```
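The command writes a standard PEFT adapter directory (an `adapter_config.json` plus `adapter_model.safetensors`), so the result can typically be loaded with the `peft` library. A minimal sketch, reusing the placeholder names from the command above:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Placeholders matching the command above; substitute real model IDs/paths.
base = AutoModelForCausalLM.from_pretrained("base_model_id_or_path")
model = PeftModel.from_pretrained(base, "output_path")

# Optionally bake the adapter back into the base weights.
model = model.merge_and_unload()
```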
# Citation

We now have a [paper](https://arxiv.org/abs/2403.13257) you can cite for the MergeKit library:

mergekit/card.py

Lines changed: 62 additions & 0 deletions
@@ -13,6 +13,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
import os
from typing import Generator, List, Optional
@@ -49,6 +50,26 @@
```
"""

CARD_TEMPLATE_LORA = """---
{metadata}
---
# {name}

This is a LoRA extracted from a language model. It was extracted using [mergekit](https://github.com/arcee-ai/mergekit).

## LoRA Details

{details}

### Parameters

The following command was used to extract this LoRA adapter:

```sh
{invocation}
```
"""


def is_hf(path: str) -> bool:
    """
@@ -175,3 +196,44 @@ def generate_card(
        name=name,
        config_yaml=config_yaml,
    )


def generate_card_lora(
    base_model_ref: ModelReference,
    finetuned_model_ref: ModelReference,
    invocation: str,
    name: str,
) -> str:
    """
    Generates a markdown card for an extracted LoRA adapter.

    Args:
        base_model_ref: Reference to the base model the adapter applies to.
        finetuned_model_ref: Reference to the fine-tuned model the adapter was extracted from.
        invocation: The command-line invocation used to extract the adapter.
        name: An optional name for the adapter.
    """
    if not name:
        name = "Untitled LoRA Model (1)"

    hf_bases = list(extract_hf_paths([base_model_ref, finetuned_model_ref]))
    tags = ["mergekit", "peft"]

    details = f"This LoRA adapter was extracted from {modelref_md(finetuned_model_ref)} and uses {modelref_md(base_model_ref)} as a base."

    if os.path.isdir(base_model_ref.model.path) or os.path.isdir(
        finetuned_model_ref.model.path
    ):
        logging.warning(
            "Some model identifiers you provided are directory paths and will appear as such in the model card; you may want to edit it."
        )

    return CARD_TEMPLATE_LORA.format(
        metadata=yaml.dump(
            {"base_model": hf_bases, "tags": tags, "library_name": "transformers"}
        ),
        name=name,
        details=details,
        base_model=base_model_ref.model.path,
        finetuned_model=finetuned_model_ref.model.path,
        invocation=invocation,
    )
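For illustration, a call to the new helper mirroring the one made in `extract_lora.py` might look like the sketch below; the model identifiers are hypothetical placeholders:

```python
from mergekit.card import generate_card_lora
from mergekit.common import ModelReference

# Hypothetical model identifiers, used only for illustration.
base_ref = ModelReference.parse("example-org/base-model")
finetuned_ref = ModelReference.parse("example-org/finetuned-model")

card_md = generate_card_lora(
    base_model_ref=base_ref,
    finetuned_model_ref=finetuned_ref,
    invocation="mergekit-extract-lora example-org/finetuned-model example-org/base-model lora_out --rank=32",
    name="Example Extracted LoRA",
)
print(card_md)
```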

mergekit/scripts/extract_lora.py

Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import bitsandbytes as bnb
import click
import torch
from peft.tuners.lora import QuantLinear
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_utils import PreTrainedModel

from mergekit.card import generate_card_lora
from mergekit.common import ModelReference
from mergekit.io import LazyTensorLoader
from mergekit.options import add_merge_options

def _low_rank_decomposition(
    weight: torch.Tensor, reduced_rank: int = 16
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decompose a 2D matrix into low-rank matrices A and B using SVD.

    :param weight: The matrix to decompose, of shape (H, W)
    :param reduced_rank: The final rank of the decomposition
    :return: A tuple of tensors (A, B) such that B @ A approximates the input
    """
    if weight.dim() != 2:
        raise ValueError(
            f"Only support 2D matrix, but your input has {weight.dim()} dimensions."
        )

    # SVD Decomposition
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)

    # Truncated matrices
    A = Vh[:reduced_rank, :]
    B = U[:, :reduced_rank] @ torch.diag(S[:reduced_rank])

    return A, B

def decompose_delta_weight(
    new_weight: torch.Tensor,
    base_weight: torch.Tensor,
    reduced_rank: int,
    device: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decompose the delta weight into low-rank matrices A and B.

    :param new_weight: The updated weight matrix after applying LoRA.
    :param base_weight: The original weight matrix before LoRA.
    :param reduced_rank: The rank for the low-rank decomposition.
    :param device: The device to perform computation on.
    :return: A tuple of tensors (A, B)
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    new_weight = new_weight.to(device)
    base_weight = base_weight.to(device)

    delta_weight = new_weight - base_weight

    max_rank = min(delta_weight.shape)
    assert (
        reduced_rank <= max_rank
    ), f"The specified rank ({reduced_rank}) must be smaller than or equal to the rank of the weight matrices ({max_rank})"

    A, B = _low_rank_decomposition(delta_weight, reduced_rank=reduced_rank)

    return A, B

def find_all_linear_names(model: PreTrainedModel) -> List[str]:
    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)

    names = []
    for name, module in model.named_modules():
        if (
            isinstance(module, cls)
            or "Linear" in module.__class__.__name__
            and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
        ):
            names.append(name)

    return names


def get_linear_module_names(model_id: str) -> List[str]:
    model = AutoModelForCausalLM.from_pretrained(
        model_id, state_dict={}, device_map="meta"
    )  # avoid loading weights as we won't need them
    linear_module_names = find_all_linear_names(model)

    return linear_module_names


def create_peft_config(
    base_model_name_or_path: str, rank: int, alpha: int, target_modules: List[str]
) -> Dict[str, Any]:
    return {
        "alpha_pattern": {},
        "auto_mapping": None,
        "base_model_name_or_path": base_model_name_or_path,
        "bias": "none",
        "fan_in_fan_out": False,
        "inference_mode": True,
        "init_lora_weights": True,
        "layers_pattern": None,
        "layers_to_transform": None,
        "loftq_config": {},
        "lora_alpha": alpha,
        "lora_dropout": 0,
        "megatron_config": None,
        "megatron_core": "megatron.core",
        "modules_to_save": None,
        "peft_type": "LORA",
        "r": rank,
        "rank_pattern": {},
        "revision": None,
        "target_modules": target_modules,
        "task_type": "CAUSAL_LM",
        "use_rslora": False,
    }


def reconstruct_invocation(args):
    """
    Reconstructs the command-line invocation string based on the given arguments stored in a dictionary.

    Parameters:
    - args: A dictionary containing the command arguments with keys matching the parameter names.
      Expected keys are 'base_model', 'finetuned_model', 'out_path', 'no_lazy_unpickle', 'desired_rank', 'model_name' and 'device'.

    Returns:
    - The reconstructed command-line invocation string.
    """
    # Provide a default value for out_path if it's not in the dictionary
    out_path = args.get("out_path", "OUTPUT_PATH")

    invocation = f"mergekit-extract-lora {args['base_model']} {args['finetuned_model']} {out_path}"
    if args.get("no_lazy_unpickle"):
        invocation += " --no-lazy-unpickle"
    invocation += f" --rank={args['desired_rank']}"
    if args.get("model_name"):
        invocation += f" --model_name={args['model_name']}"
    if args.get("device"):
        invocation += f" --device={args['device']}"

    return invocation

@click.command("mergekit-extract-lora")
@click.argument("finetuned_model", type=str)
@click.argument("base_model", type=str)
@click.argument("out_path", type=click.Path())
@click.option(
    "--no-lazy-unpickle",
    is_flag=True,
    help="Disable lazy unpickler (more stable, higher memory usage)",
)
@click.option(
    "--rank",
    "desired_rank",
    type=int,
    default=32,
    help="Rank for the low-rank decomposition",
)
@click.option(
    "--model_name",
    type=str,
    default=None,
    help="Name of the resulting model (shown in the model card)",
)
@click.option(
    "--device",
    type=str,
    default=None,
    help="PyTorch device to perform SVD computation on",
)
def main(
    finetuned_model: str,
    base_model: str,
    out_path: str,
    no_lazy_unpickle: bool,
    desired_rank: int,
    model_name: str,
    device: str,
) -> None:
    """
    Decomposes delta weights between a base model and a finetuned model, saving a PEFT model to the specified output path.

    \b
    Arguments:
    FINETUNED_MODEL - the model ID or path to use as the PEFT extraction target model.
    BASE_MODEL - the model ID or path to use as the base model.
    OUT_PATH - the output path where the PEFT model will be saved.
    """

    invocation_args = {
        "base_model": base_model,
        "finetuned_model": finetuned_model,
        "desired_rank": desired_rank,
        "device": device,
        "out_path": out_path,
        "model_name": model_name,
        "no_lazy_unpickle": no_lazy_unpickle,
    }

    os.makedirs(out_path, exist_ok=True)

    base_model_ref = ModelReference.parse(base_model)
    finetuned_model_ref = ModelReference.parse(finetuned_model)

    base_model_config = AutoConfig.from_pretrained(base_model_ref.model.path)

    linear_module_names = get_linear_module_names(base_model_ref.model.path)
    finetuned_model_linear_module_names = get_linear_module_names(
        finetuned_model_ref.model.path
    )

    assert set(linear_module_names) == set(
        finetuned_model_linear_module_names
    ), "Model architecture mismatch"

    base_loader = LazyTensorLoader(
        base_model_ref.tensor_index(), lazy_unpickle=(not no_lazy_unpickle)
    )
    finetuned_loader = LazyTensorLoader(
        finetuned_model_ref.tensor_index(), lazy_unpickle=(not no_lazy_unpickle)
    )

    lora_weights = {}
    for layer_name in tqdm(linear_module_names):
        base_weight = base_loader.get_tensor(f"{layer_name}.weight")
        finetuned_weight = finetuned_loader.get_tensor(f"{layer_name}.weight")

        lora_A, lora_B = decompose_delta_weight(
            finetuned_weight, base_weight, desired_rank, device=device
        )

        lora_weights[f"base_model.model.{layer_name}.lora_A.weight"] = lora_A.to(
            "cpu"
        ).contiguous()
        lora_weights[f"base_model.model.{layer_name}.lora_B.weight"] = lora_B.to(
            "cpu"
        ).contiguous()

    lora_config = create_peft_config(
        base_model_name_or_path=base_model_ref.model.path,
        alpha=desired_rank,  # Setting the alpha to the reduced rank value as `peft` will scale the LoRA weights by alpha/r when applying the adapter
        rank=desired_rank,
        target_modules=list(
            set([module_name.split(".")[-1] for module_name in linear_module_names])
        ),
    )

    with open(os.path.join(out_path, "adapter_config.json"), "w") as f:
        json.dump(lora_config, f, indent=2)

    save_file(lora_weights, os.path.join(out_path, "adapter_model.safetensors"))

    invocation_args.pop("out_path")  # don't include out_path for privacy
    invocation = reconstruct_invocation(invocation_args)

    card_md = generate_card_lora(
        base_model_ref=base_model_ref,
        finetuned_model_ref=finetuned_model_ref,
        invocation=invocation,
        name=model_name,
    )

    with open(os.path.join(out_path, "README.md"), "w", encoding="utf-8") as fp:
        fp.write(card_md)

    logging.info(f"PEFT LoRA adapters saved to {out_path}")


if __name__ == "__main__":
    main()
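As a quick sanity check on the decomposition used above, the returned factors can be recombined as `B @ A` and compared against the original weight delta. A small self-contained sketch using random matrices in place of real model weights (shapes and magnitudes are illustrative assumptions):

```python
import torch

from mergekit.scripts.extract_lora import decompose_delta_weight

# Random matrices standing in for a base weight and its fine-tuned counterpart.
torch.manual_seed(0)
base = torch.randn(256, 512)
finetuned = base + 0.01 * torch.randn(256, 512)

A, B = decompose_delta_weight(finetuned, base, reduced_rank=16, device="cpu")

delta = finetuned - base
approx = B @ A  # analogous to lora_B @ lora_A (peft additionally scales by alpha/r)

print(A.shape, B.shape)  # torch.Size([16, 512]) torch.Size([256, 16])
print((torch.norm(delta - approx) / torch.norm(delta)).item())  # relative reconstruction error
```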

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ mergekit-layershuffle = "mergekit.scripts.layershuffle:main"
bakllama = "mergekit.scripts.bakllama:main"
mergekit-moe = "mergekit.scripts.mixtral_moe:main"
mergekit-tokensurgeon = "mergekit.scripts.tokensurgeon:main"
mergekit-extract-lora = "mergekit.scripts.extract_lora:main"

[tool.setuptools]
packages = [
