Commit bb4fb50

FEAT Add MiSS as a replacement for Bone. (#2604)
Add MiSS, an evolution of Bone, from https://arxiv.org/abs/2409.15371. MiSS will replace Bone, which is now deprecated. A script to convert Bone checkpoints to MiSS checkpoints is included.
1 parent a91ec33 commit bb4fb50

File tree

21 files changed (+1412 / -11 lines)

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -130,6 +130,8 @@
   title: SHiRA
 - local: package_reference/c3a
   title: C3A
+- local: package_reference/miss
+  title: MiSS

   title: Adapters
 - sections:

docs/source/conceptual_guides/adapter.md

Lines changed: 9 additions & 5 deletions

@@ -122,12 +122,16 @@ HRA constructs a chain of `r` trainable Householder reflections (HRs). Because t
 The higher `r`, the more trainable parameters, resulting in a larger model capacity and better performance. Besides, due to the chain structure, the orthogonality of HR planes impacts the capacity and regularity of HRA. To achieve a trade-off between the model capacity and regularity, an orthogonality regularizer of the HR planes is added to the loss function. The weight \\(\lambda\\) can control the strength of the regularizer.

 ## Bone
-[DiSHA](https://huggingface.co/papers/2409.15371) A novel PEFT technique distinct from LoRA, called Dimension-Sharding Adaptation (DiSHA). By dividing the original weights into multiple subspaces that share a single matrix for weight updates, DiSHA simplifies the process by requiring the trainable matrix to be initialized to zero, eliminating the need for complex initialization as in some LoRA variants. Bone and Bat are derivative structures of DiSHA. Bone significantly improves computational efficiency while saving memory, whereas Bat addresses the limitation of Bone's linear update by employing a non-linear update to break through the upper bound.
+[MiSS](https://huggingface.co/papers/2409.15371) is the updated version of this work; the paper is now titled *MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing*, and Bone is deprecated in favor of MiSS.
+If you already have a Bone checkpoint, you can use `/scripts/convert-bone-to-miss.py` to convert it into a MiSS checkpoint and continue training with MiSS.

-<small><a href="https://huggingface.co/papers/2409.15371">DiSHA: Dimension-Sharding Adaptation with Fast Convergence and Fast Computation</a></small>
+## MiSS
+[MiSS](https://huggingface.co/papers/2409.15371) (Matrix Shard Sharing) is a novel Parameter-Efficient Fine-Tuning (PEFT) method designed to address the trade-off between adaptability and efficiency in Large Language Models. Its core is a simple shard-sharing mechanism: the weight matrix is decomposed into multiple fragments that all share a single trainable "common fragment", and the final low-rank update matrix is constructed by replicating this shared fragment across the shards. As a result, MiSS adopts a low-rank structure, requires only a single trainable matrix, and uses an update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.

-Intuitively, the shape of a single trainable matrix in Bone is consistent with `lora_B`, so the `r` parameter in Bone is less than the `r` in LoRA by (`in_feature * r`).
+<small><a href="https://huggingface.co/papers/2409.15371">MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing</a></small>

-Note: Bat's r (b) is special and requires that weight W satisfies the conditions `in_features % r == 0` and `out_features % r == 0`. Additionally, when `in_features == out_features` and Bone-r equals LoRA-r, Bone's number of trainable parameters is only half that of LoRA.
+Intuitively, the single trainable matrix in MiSS has the same shape as `lora_B`, so for the same `r`, MiSS has `in_features * r` fewer trainable parameters than LoRA.

-Although the nonlinear updates of Bat bring some performance improvements, they also increase computational overhead. Its main purpose is to provide researchers with a direction for improvement. Therefore, we recommend fine-tuning the comprehensive Bone model instead.
+Note: Bat's `r` (b) is special and requires that the weight W satisfies `in_features % r == 0` and `out_features % r == 0`. Additionally, when `in_features == out_features` and MiSS-r equals LoRA-r, MiSS has only half as many trainable parameters as LoRA.
+
+Although Bat's nonlinear updates bring some performance improvement, they also increase computational overhead; Bat mainly serves to point researchers toward a direction for further improvement, so we recommend fine-tuning with standard MiSS instead.
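To make the shard-sharing idea described above concrete, here is a minimal, illustrative PyTorch sketch of how a single `lora_B`-shaped matrix could be replicated across column shards to form the weight update. The shapes and the tiling direction are assumptions drawn from the prose above, not PEFT's actual MiSS implementation.

```python
import torch

# Toy sizes; assumes in_features % r == 0.
out_features, in_features, r = 64, 64, 16

W = torch.randn(out_features, in_features)    # frozen base weight
miss_weight = torch.zeros(out_features, r)    # single trainable matrix, shaped like lora_B

# View the columns of W as in_features // r shards of width r and let every
# shard share the same trainable fragment, i.e. replicate it across shards.
delta_W = miss_weight.repeat(1, in_features // r)  # (out_features, in_features)

W_adapted = W + delta_W

# Trainable parameters: out_features * r = 1024, i.e. half of LoRA's
# (in_features + out_features) * r when in_features == out_features.
print(miss_weight.numel())
```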
docs/source/package_reference/miss.md

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# MiSS

[MiSS](https://huggingface.co/papers/2409.15371) (MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing) is a novel PEFT method that adopts a low-rank structure, requires only a single trainable matrix, and introduces a new update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.

The abstract from the paper is:

*Parameter-Efficient Fine-Tuning (PEFT) methods, particularly Low-Rank Adaptation (LoRA), effectively reduce the number of trainable parameters in Large Language Models (LLMs). However, as model scales continue to grow, the demand for computational resources remains a significant challenge. Existing LoRA variants often struggle to strike an optimal balance between adaptability (model performance and convergence speed) and efficiency (computational overhead, memory usage, and initialization time). This paper introduces MiSS (Matrix Shard Sharing), a novel PEFT approach that addresses this trade-off through a simple shard-sharing mechanism. MiSS leverages the insight that a low-rank adaptation can be achieved by decomposing the weight matrix into multiple fragment matrices and utilizing a shared, trainable common fragment. This method constructs the low-rank update matrix through the replication of these shared, partitioned shards. We also propose a hardware-efficient and broadly applicable implementation for MiSS. Extensive experiments conducted on a range of tasks, alongside a systematic analysis of computational performance, demonstrate MiSS's superiority. The results show that MiSS significantly outperforms standard LoRA and its prominent variants in both model performance metrics and computational efficiency, including initialization speed and training throughput. By effectively balancing expressive power and resource utilization, MiSS offers a compelling solution for efficiently adapting large-scale models*.

## MissConfig

[[autodoc]] tuners.miss.config.MissConfig

## MissModel

[[autodoc]] tuners.miss.model.MissModel
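For orientation, a minimal usage sketch of the two classes documented above; the base model name and `target_modules` below are illustrative placeholders, not recommendations from this commit:

```python
from transformers import AutoModelForCausalLM

from peft import MissConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model
config = MissConfig(r=16, target_modules=["q_proj", "v_proj"])          # placeholder modules
peft_model = get_peft_model(base_model, config)  # wraps the base model with MiSS adapters
peft_model.print_trainable_parameters()
```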

examples/miss_finetuning/README.md

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
# MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing
## Introduction ([Paper](https://huggingface.co/papers/2409.15371), [code](https://github.com/JL-er/MiSS))
MiSS (Matrix Shard Sharing) is a novel PEFT method that adopts a low-rank structure, requires only a single trainable matrix, and introduces a new update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.

## Quick Start
```python
import torch
from peft import MissConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id

miss_config = MissConfig(
    r=64
)
# bat: In this mode, you can enable nonlinear updates across different shards.
# miss_config = MissConfig(
#     r=64,
#     init_weights="bat"
# )

# mini: In this mode, you can set a smaller rank to use fewer trainable parameters, but it is recommended to keep `out_features % mini_r == 0`.
# miss_config = MissConfig(
#     r=64,
#     init_weights="mini",
#     mini_r=8
# )
peft_model = get_peft_model(model, miss_config)

peft_model.print_trainable_parameters()

dataset = load_dataset("imdb", split="train[:1%]")

training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("miss-llama-2-7b")
```

To utilize the fine-tuned MiSS modules, simply run the following command:
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
peft_model = PeftModel.from_pretrained(model, "miss-llama-2-7b")
```

## Advanced Usage

### Fine-tune
```shell
# Bat performs better than MiSS, but it uses more memory and is twice as slow. If you want to use the Bat method, you only need to add the parameter init_weights="bat".
python miss_finetuning.py \
    --base_model_name_or_path meta-llama/Llama-2-7b-hf \
    --output_dir output/miss-llama-2-7b-metamath-10k \
    --miss_r 64 \
    --init_weights True \
    --bits bf16 \
    --data_path meta-math/MetaMathQA \
    --dataset_split train[:100000] \
    --dataset_field query response \
    --bf16 True \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --logging_steps 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --tf32 True \
    --report_to none
```

# Citation
```bib
@misc{kang2025balancingloraperformanceefficiency,
      title={Balancing LoRA Performance and Efficiency with Simple Shard Sharing},
      author={Jiale Kang and Qingyu Yin},
      year={2025},
      eprint={2409.15371},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2409.15371},
}
```
examples/miss_finetuning/miss_finetuning.py

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import Literal, Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import SFTConfig, SFTTrainer

from peft import MissConfig, get_peft_model


@dataclass
class ScriptArguments(SFTConfig):
    # model configs
    base_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The name or path of the fp32/16 base model."}
    )
    bits: str = field(default="bf16", metadata={"help": "(`['bf16', 'fp16', 'fp32']`)"})
    init_weights: Literal[True, "bat", "mini"] = field(
        default=True,
        metadata={
            "help": (
                "True -> MiSS efficiency and balance; `bat` -> Bat; `mini` -> smaller MiSS efficiency and balance"
            ),
        },
    )
    miss_r: int = field(default=16)
    merge_and_save: bool = field(default=False)
    # dataset configs
    data_path: str = field(default="imdb", metadata={"help": "Path to the training data."})
    dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`):"})
    dataset_field: list[str] = field(default=None, metadata={"help": "Fields of dataset input and output."})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
print(script_args)

print(f"Load pre-processed residual model in {script_args.bits} bits.")
if script_args.bits in ["nf4", "fp4", "int8"]:
    print("MiSS currently does not support quantization.")

elif script_args.base_model_name_or_path is not None:
    print(f"No available pre-processed model, manually initialize a MiSS using {script_args.base_model_name_or_path}.")
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_name_or_path,
        torch_dtype=(
            torch.float16
            if script_args.bits == "fp16"
            else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32)
        ),
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    miss_config = MissConfig(
        r=script_args.miss_r,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
        init_weights=script_args.init_weights,
    )
    peft_model = get_peft_model(model, miss_config)

print(peft_model)
peft_model.print_trainable_parameters()

print(f"Training MiSS with trl on the {script_args.data_path}[{script_args.dataset_split}] dataset.")
dataset = load_dataset(script_args.data_path, split=script_args.dataset_split)
dataset = dataset.map(
    lambda example: {
        "text": f"### USER: {example[script_args.dataset_field[0]]}\n### ASSISTANT: {example[script_args.dataset_field[1]]}"
    }
)

trainer = SFTTrainer(
    model=peft_model,
    args=script_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_state()

peft_model.save_pretrained(
    os.path.join(script_args.output_dir, "miss_ft"),
)

if script_args.merge_and_save:
    model = peft_model.merge_and_unload()
    model.save_pretrained(os.path.join(script_args.output_dir, "miss_merged"))
    tokenizer.save_pretrained(os.path.join(script_args.output_dir, "miss_merged"))

scripts/convert-bone-to-miss.py

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Your Organization/Project. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert Bone checkpoint to MiSS format."""

import argparse
import json
import os
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file

from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME


def convert_bone_to_miss(bone_dir: Path, miss_dir: Path) -> None:
    """Convert Bone checkpoint files to MiSS format."""
    bone_config_path = bone_dir / CONFIG_NAME
    miss_config_path = miss_dir / CONFIG_NAME
    if not os.path.exists(miss_dir):
        os.makedirs(miss_dir, exist_ok=True)
    with open(bone_config_path, encoding="utf-8") as f:
        config = json.load(f)

    config["peft_type"] = "MISS"

    with open(miss_config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    bone_weight_path = bone_dir / SAFETENSORS_WEIGHTS_NAME
    miss_weight_path = miss_dir / SAFETENSORS_WEIGHTS_NAME

    new_data = {}

    with safe_open(bone_weight_path, framework="pt") as f:
        for old_key in f.keys():
            tensor = f.get_tensor(old_key)
            new_key = old_key.replace(".bone_", ".miss_")
            new_data[new_key] = tensor

    save_file(new_data, miss_weight_path)

    print(f"Converted checkpoint saved at {miss_weight_path}")


def main() -> None:
    parser = argparse.ArgumentParser(description="Convert Bone checkpoint to MiSS format.")
    parser.add_argument("bone_dir", type=Path, help="Directory containing Bone checkpoint files")
    parser.add_argument("miss_dir", type=Path, help="Directory to save MiSS checkpoint files")
    args = parser.parse_args()

    args.miss_dir.mkdir(parents=True, exist_ok=True)
    convert_bone_to_miss(args.bone_dir, args.miss_dir)


if __name__ == "__main__":
    main()
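As a usage note, the converter is a standalone CLI script; a typical invocation looks like the following, where both checkpoint directories are placeholders:

```shell
# Convert an existing Bone adapter checkpoint into a MiSS checkpoint.
python scripts/convert-bone-to-miss.py ./bone-llama-2-7b ./miss-llama-2-7b
```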

src/peft/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -75,6 +75,8 @@
     LoraConfig,
     LoraModel,
     LoraRuntimeConfig,
+    MissConfig,
+    MissModel,
     MultitaskPromptTuningConfig,
     MultitaskPromptTuningInit,
     OFTConfig,
@@ -161,6 +163,8 @@
     "LoraConfig",
     "LoraModel",
     "LoraRuntimeConfig",
+    "MissConfig",
+    "MissModel",
     "MultitaskPromptTuningConfig",
     "MultitaskPromptTuningInit",
     "OFTConfig",

src/peft/tuners/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -33,6 +33,7 @@
     get_eva_state_dict,
     initialize_lora_eva_weights,
 )
+from .miss import MissConfig, MissModel
 from .mixed import MixedModel
 from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit
 from .oft import OFTConfig, OFTModel
@@ -78,6 +79,8 @@
     "LoraConfig",
     "LoraModel",
     "LoraRuntimeConfig",
+    "MissConfig",
+    "MissModel",
     "MixedModel",
     "MultitaskPromptEmbedding",
     "MultitaskPromptTuningConfig",

0 commit comments
