
Commit 09c63e6

Add Model Breadcrumbs merge method (#228)
Implements the method described in [Model Breadcrumbs: Scaling Multi-Task Model Merging with Sparse Masks](https://arxiv.org/abs/2312.06795).
1 parent 020d557 commit 09c63e6

6 files changed (+119 −12 lines)

README.md

Lines changed: 23 additions & 10 deletions
@@ -116,16 +116,18 @@ Several examples of merge configurations are available in [`examples/`](examples/).

 A quick overview of the currently supported merge methods:

 | Method                                                                                            | `merge_method` value | Multi-Model | Uses base model |
 | ------------------------------------------------------------------------------------------------ | -------------------- | ----------- | --------------- |
 | Linear ([Model Soups](https://arxiv.org/abs/2203.05482))                                          | `linear`             | ✅          | ❌              |
 | SLERP                                                                                             | `slerp`              | ❌          | ✅              |
 | [Task Arithmetic](https://arxiv.org/abs/2212.04089)                                               | `task_arithmetic`    | ✅          | ✅              |
 | [TIES](https://arxiv.org/abs/2306.01708)                                                          | `ties`               | ✅          | ✅              |
 | [DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708)                 | `dare_ties`          | ✅          | ✅              |
 | [DARE](https://arxiv.org/abs/2311.03099) [Task Arithmetic](https://arxiv.org/abs/2212.04089)      | `dare_linear`        | ✅          | ✅              |
 | Passthrough                                                                                       | `passthrough`        | ❌          | ❌              |
+| [Model Breadcrumbs](https://arxiv.org/abs/2312.06795)                                             | `breadcrumbs`        | ✅          | ✅              |
+| [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) + [TIES](https://arxiv.org/abs/2306.01708)  | `breadcrumbs_ties`   | ✅          | ✅              |
 | [Model Stock](https://arxiv.org/abs/2403.19522)                                                   | `model_stock`        | ✅          | ✅              |

 ### Linear

@@ -168,6 +170,17 @@ Parameters: same as [TIES](#ties) for `dare_ties`, or [Linear](#linear) for `dare_linear`.

 `passthrough` is a no-op that simply passes input tensors through unmodified. It is meant to be used for layer-stacking type merges where you have only one input model. Useful for frankenmerging.

+### [Model Breadcrumbs](https://arxiv.org/abs/2312.06795)
+
+An extension of task arithmetic that discards both small and extremely large differences from the base model. As with DARE, the Model Breadcrumbs algorithm can be used with (`breadcrumbs_ties`) or without (`breadcrumbs`) the sign consensus algorithm of TIES.
+
+Parameters: same as [Linear](#linear), plus:
+
+- `density` - fraction of weights in differences from the base model to retain
+- `gamma` - fraction of largest magnitude differences to remove
+
+Note that `gamma` corresponds with the parameter `β` described in the paper, while `density` is the final density of the sparsified tensors (related to `γ` and `β` by `density = 1 - γ - β`). For good default values, try `density: 0.9` and `gamma: 0.01`.
+
 ### [Model Stock](https://arxiv.org/abs/2403.19522)

 Uses some neat geometric properties of fine-tuned models to compute good weights for linear interpolation. Requires at least three models, including a base model.
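To make the `density`/`gamma` bookkeeping concrete, here is a small sketch (not part of the commit; the tensor and values are illustrative) of how the two knobs translate into element counts, mirroring the arithmetic in `magnitude_outliers` further down:

```python
import torch

# Toy stand-in for a task vector (difference from the base model).
delta = torch.randn(10, 100)  # 1000 elements

density, gamma = 0.9, 0.01  # the defaults suggested above

num_elems = delta.numel()             # 1000
n_top = int(gamma * num_elems)        # 10 largest-magnitude diffs dropped (the paper's β)
target_n = int(density * num_elems)   # 900 diffs kept
n_bot = num_elems - target_n - n_top  # 90 smallest diffs dropped (the paper's γ)

assert target_n + n_top + n_bot == num_elems  # i.e. density = 1 - γ - β
```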

mergekit/merge_methods/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -61,6 +61,20 @@ def get(method: str) -> MergeMethod:
             default_normalize=False,
             default_rescale=True,
         )
+    elif method == "breadcrumbs":
+        return GeneralizedTaskArithmeticMerge(
+            consensus_method=None,
+            sparsification_method=SparsificationMethod.magnitude_outliers,
+            default_normalize=False,
+            default_rescale=False,
+        )
+    elif method == "breadcrumbs_ties":
+        return GeneralizedTaskArithmeticMerge(
+            consensus_method=ConsensusMethod.sum,
+            sparsification_method=SparsificationMethod.magnitude_outliers,
+            default_normalize=False,
+            default_rescale=False,
+        )
     elif method == "model_stock":
         return ModelStockMerge()
     raise RuntimeError(f"Unimplemented merge method {method}")
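Illustrative only: with this registry in place, the two new method names resolve to configured `GeneralizedTaskArithmeticMerge` instances, and anything unregistered hits the fall-through `RuntimeError` (`breadcrumbs_dare` below is a hypothetical name used just to show the failure path):

```python
from mergekit.merge_methods import get

breadcrumbs = get("breadcrumbs")            # magnitude-outlier sparsification, no sign consensus
breadcrumbs_ties = get("breadcrumbs_ties")  # same sparsification, plus the TIES sum consensus

try:
    get("breadcrumbs_dare")  # hypothetical, not registered
except RuntimeError as e:
    print(e)  # Unimplemented merge method breadcrumbs_dare
```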

mergekit/merge_methods/generalized_task_arithmetic.py

Lines changed: 14 additions & 1 deletion
@@ -52,10 +52,18 @@ def parameters(self) -> List[ConfigParameterDef]:
         ]

     def tensor_parameters(self) -> List[ConfigParameterDef]:
-        return [
+        res = [
             ConfigParameterDef(name="weight", required=True),
             ConfigParameterDef(name="density", required=False, default_value=1.0),
         ]
+        if self.sparsification_method == SparsificationMethod.magnitude_outliers:
+            res.append(
+                ConfigParameterDef(
+                    name="gamma",
+                    default_value=0.01,
+                )
+            )
+        return res

     def make_task(
         self,

@@ -111,11 +119,16 @@ def execute(
         # sparsify
         if self.method.sparsification_method:
             for tv_info in tvs:
+                kwargs = {}
+                if "gamma" in tv_info:
+                    kwargs["gamma"] = tv_info["gamma"]
+
                 tv_info["delta"] = sparsify(
                     tv_info["delta"],
                     density=tv_info["density"],
                     method=self.method.sparsification_method,
                     rescale=self.rescale,
+                    **kwargs,
                 )

         deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
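One visible effect of the `tensor_parameters` change: `gamma` only becomes a configurable per-tensor parameter when the magnitude-outlier sparsifier is selected, which is why `execute` forwards it conditionally rather than always. A quick sketch (assuming the registry from `mergekit/merge_methods/__init__.py` above):

```python
from mergekit.merge_methods import get

# TIES uses plain magnitude sparsification, so no gamma parameter is exposed.
ties_params = [p.name for p in get("ties").tensor_parameters()]
# Breadcrumbs uses magnitude_outliers, so gamma appears with its 0.01 default.
breadcrumbs_params = [p.name for p in get("breadcrumbs").tensor_parameters()]

assert ties_params == ["weight", "density"]
assert breadcrumbs_params == ["weight", "density", "gamma"]
```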

mergekit/sparsify.py

Lines changed: 47 additions & 1 deletion
@@ -21,6 +21,7 @@
 class SparsificationMethod(str, Enum):
     magnitude = "magnitude"
     random = "random"
+    magnitude_outliers = "magnitude_outliers"


 def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor):

@@ -41,7 +42,7 @@ def magnitude(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
     if density >= 1:
         return tensor

-    k = int(density * tensor.view(-1).shape[0])
+    k = int(density * tensor.numel())

     assert k > 0, "not gonna zero out the whole tensor buddy"
     mask = torch.zeros_like(tensor)

@@ -59,6 +60,48 @@ def magnitude(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
     return res


+def magnitude_outliers(
+    tensor: torch.Tensor, density: float, rescale: bool, gamma: float = 0.01
+):
+    """Masks out the smallest values in addition to large outliers.
+
+    The `gamma` proportion of the largest-magnitude weights is removed first,
+    then the smallest weights are removed until the desired density is reached.
+
+    Args:
+        tensor (torch.Tensor): The tensor to sparsify.
+        density (float): The proportion of weights to retain.
+        gamma (float): Fraction of largest-magnitude weights to remove.
+    """
+    if density >= 1:
+        return tensor
+
+    num_elems = tensor.numel()
+    target_n = int(density * num_elems)
+    n_top = int(gamma * num_elems)
+    n_bot = num_elems - target_n - n_top
+
+    if n_bot < 0:
+        # cut down on the number of large weights to remove in
+        # order to hit the target density
+        n_top += n_bot
+        n_bot = 0
+
+    w = tensor.abs().view(-1)
+    if w.device.type == "cpu":
+        w = w.float()
+    indices = torch.sort(w, descending=False).indices
+    mask = torch.zeros_like(tensor)
+
+    # explicit end index, so the slice stays valid when n_top == 0
+    mask.view(-1)[indices[n_bot : num_elems - n_top]] = 1
+
+    if rescale:
+        res = rescale_sum(tensor, mask)
+    else:
+        res = tensor * mask
+    return res
+
+
 def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
     if density >= 1:
         return tensor

@@ -82,11 +125,14 @@ def sparsify(
     tensor: torch.Tensor,
     density: float,
     method: SparsificationMethod,
+    gamma: float = 0,
     rescale: bool = False,
 ) -> torch.Tensor:
     if method == SparsificationMethod.magnitude:
         return magnitude(tensor, density=density, rescale=rescale)
     elif method == SparsificationMethod.random:
         return bernoulli(tensor, density=density, rescale=rescale)
+    elif method == SparsificationMethod.magnitude_outliers:
+        return magnitude_outliers(tensor, density=density, rescale=rescale, gamma=gamma)
     else:
         raise NotImplementedError(method)
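As a quick sanity check of the new dispatch path, in the spirit of the tests below (a sketch, not part of the commit):

```python
import torch

from mergekit.sparsify import SparsificationMethod, sparsify

t = torch.randn(1000)
out = sparsify(
    t, density=0.9, method=SparsificationMethod.magnitude_outliers, gamma=0.01
)

# The 1% largest-magnitude and 9% smallest-magnitude entries were zeroed out.
assert torch.count_nonzero(out) == int(t.numel() * 0.9)
# The single largest-magnitude entry is among the removed outliers.
assert out[t.abs().argmax()] == 0
```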

tests/test_basic_merges.py

Lines changed: 6 additions & 0 deletions
@@ -105,6 +105,12 @@ def test_task_arithmetic_merge(self, model_a, model_b, model_c):
         )
         run_and_check_merge(config)

+    def test_breadcrumbs_merge(self, model_a, model_b, model_c):
+        config = self.two_model_config(
+            model_a, model_b, merge_method="breadcrumbs", base_model=model_c
+        )
+        run_and_check_merge(config)
+
     def test_ties_merge(self, model_a, model_b, model_c):
         config = self.two_model_config(
             model_a,

tests/test_sparsify.py

Lines changed: 15 additions & 0 deletions
@@ -28,6 +28,21 @@ def test_partial_density(self, sample_tensor):
         )
         assert torch.count_nonzero(result) == sample_tensor.view(-1).shape[0] // 2

+    def test_outliers(self, sample_tensor):
+        for gamma_0 in [0.1, 0.2, 0.5, 1.0]:
+            for density in [0.1, 0.3, 0.5, 0.6, 0.9, 1.0]:
+                sparsity = 1 - density
+                gamma = gamma_0 * sparsity
+                result = sparsify(
+                    sample_tensor,
+                    density=density,
+                    method=SparsificationMethod.magnitude_outliers,
+                    gamma=gamma,
+                )
+                assert torch.count_nonzero(result) == int(
+                    sample_tensor.view(-1).shape[0] * density
+                )
+

 class TestBernoulli:
     NUM_ITERATIONS = 1000
