
Commit 8d689e4

Fix BERT merging (#295)
Various changes to better support merging BERT based models.
1 parent ca96e86 commit 8d689e4

File tree

8 files changed: +177, -75 lines changed

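For context on what these changes enable: with the renamed weights and alias lists below, a BERT checkpoint can go through a merge end to end. A minimal smoke test using mergekit's Python API; the model names, merge method, and output path are illustrative, not taken from this commit:

from mergekit.config import MergeConfiguration
from mergekit.merge import MergeOptions, run_merge

# Equal-weight linear merge of two BERT checkpoints (illustrative models).
config = MergeConfiguration.model_validate(
    {
        "merge_method": "linear",
        "dtype": "float32",
        "models": [
            {"model": "bert-base-uncased", "parameters": {"weight": 0.5}},
            {"model": "textattack/bert-base-uncased-SST-2", "parameters": {"weight": 0.5}},
        ],
    }
)
run_merge(config, "./merged-bert", options=MergeOptions())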

mergekit/_data/architectures/bert.json

Lines changed: 94 additions & 27 deletions

@@ -5,101 +5,168 @@
     ],
     "pre_weights": [
         {
-            "name": "bert.embeddings.position_embeddings.weight"
+            "name": "embeddings.position_embeddings.weight",
+            "aliases": [
+                "bert.embeddings.position_embeddings.weight"
+            ]
         },
         {
-            "name": "bert.embeddings.token_type_embeddings.weight"
+            "name": "embeddings.token_type_embeddings.weight",
+            "aliases": [
+                "bert.embeddings.token_type_embeddings.weight"
+            ]
         },
         {
-            "name": "bert.embeddings.word_embeddings.weight",
-            "is_embed": true
+            "name": "embeddings.word_embeddings.weight",
+            "is_embed": true,
+            "aliases": [
+                "bert.embeddings.word_embeddings.weight"
+            ]
         },
         {
-            "name": "bert.embeddings.LayerNorm.bias",
+            "name": "embeddings.LayerNorm.bias",
             "aliases": [
+                "embeddings.LayerNorm.beta",
+                "bert.embeddings.LayerNorm.bias",
                 "bert.embeddings.LayerNorm.beta"
             ]
         },
         {
-            "name": "bert.embeddings.LayerNorm.weight",
+            "name": "embeddings.LayerNorm.weight",
             "aliases": [
-                "bert.embeddings.LayerNorm.gamma"
+                "embeddings.LayerNorm.gamma",
+                "bert.embeddings.LayerNorm.weight",
+                "bert.embeddings.LayerNorm.gamma",
+                "bert.embeddings.LayerNorm.weight"
             ]
         },
         {
-            "name": "bert.embeddings.position_ids",
+            "name": "embeddings.position_ids",
             "optional": true,
-            "force_dtype": "int64"
+            "force_dtype": "int64",
+            "aliases": [
+                "bert.embeddings.position_ids"
+            ]
         }
     ],
     "post_weights": [
         {
-            "name": "pooler.dense.weight"
+            "name": "pooler.dense.weight",
+            "aliases": [
+                "bert.pooler.dense.weight"
+            ]
         },
         {
-            "name": "pooler.dense.bias"
+            "name": "pooler.dense.bias",
+            "aliases": [
+                "bert.pooler.dense.bias"
+            ]
         }
     ],
     "num_layers_config_key": "num_hidden_layers",
     "layer_templates": {
         "weights": [
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.self.query.weight"
+                "name": "encoder.layer.${layer_index}.attention.self.query.weight",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.attention.self.query.weight"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.self.query.bias"
+                "name": "encoder.layer.${layer_index}.attention.self.query.bias",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.attention.self.query.bias"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.self.key.weight"
+                "name": "encoder.layer.${layer_index}.attention.self.key.weight",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.attention.self.key.weight"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.self.key.bias"
+                "name": "encoder.layer.${layer_index}.attention.self.key.bias",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.attention.self.key.bias"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.self.value.weight"
+                "name": "encoder.layer.${layer_index}.attention.self.value.weight",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.attention.self.value.weight"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.self.value.bias"
+                "name": "encoder.layer.${layer_index}.attention.self.value.bias",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.attention.self.value.bias"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.output.dense.weight"
+                "name": "encoder.layer.${layer_index}.attention.output.dense.weight",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.attention.output.dense.weight"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.output.dense.bias"
+                "name": "encoder.layer.${layer_index}.attention.output.dense.bias",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.attention.output.dense.bias"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.bias",
+                "name": "encoder.layer.${layer_index}.attention.output.LayerNorm.bias",
                 "aliases": [
+                    "encoder.layer.${layer_index}.attention.output.LayerNorm.beta",
+                    "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.bias",
                     "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.beta"
                 ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.weight",
+                "name": "encoder.layer.${layer_index}.attention.output.LayerNorm.weight",
                 "aliases": [
+                    "encoder.layer.${layer_index}.attention.output.LayerNorm.gamma",
+                    "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.weight",
                     "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.gamma"
                 ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.intermediate.dense.weight"
+                "name": "encoder.layer.${layer_index}.intermediate.dense.weight",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.intermediate.dense.weight"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.intermediate.dense.bias"
+                "name": "encoder.layer.${layer_index}.intermediate.dense.bias",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.intermediate.dense.bias"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.output.dense.weight"
+                "name": "encoder.layer.${layer_index}.output.dense.weight",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.output.dense.weight"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.output.dense.bias"
+                "name": "encoder.layer.${layer_index}.output.dense.bias",
+                "aliases": [
+                    "bert.encoder.layer.${layer_index}.output.dense.bias"
+                ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.output.LayerNorm.bias",
+                "name": "encoder.layer.${layer_index}.output.LayerNorm.bias",
                 "aliases": [
+                    "encoder.layer.${layer_index}.output.LayerNorm.beta",
+                    "bert.encoder.layer.${layer_index}.output.LayerNorm.bias",
                     "bert.encoder.layer.${layer_index}.output.LayerNorm.beta"
                 ]
             },
             {
-                "name": "bert.encoder.layer.${layer_index}.output.LayerNorm.weight",
+                "name": "encoder.layer.${layer_index}.output.LayerNorm.weight",
                 "aliases": [
+                    "encoder.layer.${layer_index}.output.LayerNorm.gamma",
+                    "bert.encoder.layer.${layer_index}.output.LayerNorm.weight",
                     "bert.encoder.layer.${layer_index}.output.LayerNorm.gamma"
                 ]
             }
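The pattern above is uniform: each tensor's canonical name drops the bert. prefix (matching bare BertModel checkpoints), while the prefixed BertForMaskedLM-style names and the old gamma/beta LayerNorm names are kept as aliases. A minimal sketch of how alias resolution against a state dict can work; resolve_tensor is a hypothetical helper for illustration, not mergekit's actual loader:

from typing import Dict, Optional, Sequence

import torch

def resolve_tensor(
    state_dict: Dict[str, torch.Tensor],
    name: str,
    aliases: Optional[Sequence[str]] = None,
) -> Optional[torch.Tensor]:
    # Try the canonical name first, then each alias in order.
    for candidate in (name, *(aliases or ())):
        if candidate in state_dict:
            return state_dict[candidate]
    return None

# A BertForMaskedLM checkpoint serializes the prefixed name; the canonical
# unprefixed name still resolves through the alias list defined above.
sd = {"bert.embeddings.LayerNorm.weight": torch.ones(768)}
tensor = resolve_tensor(
    sd,
    "embeddings.LayerNorm.weight",
    aliases=(
        "embeddings.LayerNorm.gamma",
        "bert.embeddings.LayerNorm.weight",
        "bert.embeddings.LayerNorm.gamma",
    ),
)
assert tensor is not None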

mergekit/architecture.py

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ class WeightInfo(BaseModel, frozen=True):
     input_space: Optional[str] = None
     output_space: Optional[str] = None
     optional: bool = False
-    aliases: Optional[List[str]] = None
+    aliases: Optional[Tuple[str, ...]] = None
     force_dtype: Optional[str] = None
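The List-to-Tuple switch is more than style: WeightInfo is declared frozen=True, which makes instances hashable so they can be used as dict keys or held inside other hashable models, and hashing fails at runtime if any field value is a list. A minimal illustration of that failure mode; this is my reading of the motivation, not text from the commit:

from typing import List, Optional, Tuple

from pydantic import BaseModel

class WithList(BaseModel, frozen=True):
    aliases: Optional[List[str]] = None

class WithTuple(BaseModel, frozen=True):
    aliases: Optional[Tuple[str, ...]] = None

hash(WithTuple(aliases=("a", "b")))  # fine: every field value is hashable
try:
    hash(WithList(aliases=["a", "b"]))
except TypeError as err:
    print(err)  # unhashable type: 'list'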

mergekit/common.py

Lines changed: 5 additions & 29 deletions

@@ -23,7 +23,6 @@
     Dict,
     Generic,
     Iterator,
-    List,
     Mapping,
     Optional,
     Tuple,
@@ -81,6 +80,7 @@ class ModelReference(BaseModel, frozen=True):
 
     model: ModelPath
     lora: Optional[ModelPath] = None
+    override_architecture: Optional[str] = None
 
     def merged(
         self, cache_dir: Optional[str] = None, trust_remote_code: bool = False
@@ -122,11 +122,14 @@ def merged(
         return ModelReference(model=out_path)
 
     def config(self, trust_remote_code: bool = False) -> PretrainedConfig:
-        return AutoConfig.from_pretrained(
+        res = AutoConfig.from_pretrained(
             self.model.path,
             revision=self.model.revision,
             trust_remote_code=trust_remote_code,
         )
+        if self.override_architecture:
+            res.architectures = [self.override_architecture]
+        return res
 
     def tensor_index(self, cache_dir: Optional[str] = None) -> ShardedTensorIndex:
         assert self.lora is None
@@ -209,33 +212,6 @@ def dtype_from_name(name: Optional[str]) -> Optional[torch.dtype]:
     raise RuntimeError(f'Unimplemented dtype "{name}"')
 
 
-def rectify_embed_sizes(param_name: str, tensors: List[torch.Tensor]):
-    # TODO: use arch_info.embed_weights() instead
-    if ("lm_head" in param_name or "embed_tokens" in param_name) and all(
-        len(t.shape) == 2 for t in tensors
-    ):
-        # special case - if lm_head.weight or embed_tokens.weight have a size
-        # mismatch, take the largest common submatrix of all of them
-        if take_common_submatrix(tensors):
-            logging.warning(
-                f"Using common submatrix of size {tensors[0].shape} for {param_name}"
-            )
-
-
-def take_common_submatrix(tensors: List[torch.Tensor]) -> bool:
-    min_size = [None, None]
-    for t in tensors:
-        for idx in range(2):
-            if min_size[idx] is None or t.shape[idx] < min_size[idx]:
-                min_size[idx] = t.shape[idx]
-
-    if not all(t.shape == torch.Size(min_size) for t in tensors):
-        for idx in range(len(tensors)):
-            tensors[idx] = tensors[idx][: min_size[0], : min_size[1]]
-        return True
-    return False
-
-
 def parse_kmb(value: Union[str, int]) -> int:
     if isinstance(value, int):
         return value
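The new override_architecture field lets a caller force how a checkpoint is classified, for example treating a task-specific BERT checkpoint as a plain BertModel. A sketch of intended usage; the model name is illustrative and the ModelPath keyword arguments are my assumption about its constructor:

from mergekit.common import ModelPath, ModelReference

ref = ModelReference(
    model=ModelPath(path="bert-base-uncased"),  # illustrative checkpoint
    override_architecture="BertModel",
)
cfg = ref.config()
# AutoConfig returns whatever the checkpoint's config.json declares;
# the override then replaces the architectures list.
assert cfg.architectures == ["BertModel"]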

mergekit/merge_methods/generalized_task_arithmetic.py

Lines changed: 7 additions & 5 deletions

@@ -81,15 +81,15 @@ def make_task(
             int8_mask=parameters["int8_mask"],
             normalize=parameters["normalize"],
             rescale=parameters["rescale"],
-            out_tensor_name=output_weight.name,
+            weight_info=output_weight,
         )
 
 
 class GTATask(Task[torch.Tensor]):
     method: GeneralizedTaskArithmeticMerge
     tensors: GatherTensors
     base_model: ModelReference
-    out_tensor_name: str
+    weight_info: WeightInfo
     tensor_parameters: ImmutableMap[ModelReference, Any]
     int8_mask: bool
     normalize: bool
@@ -108,7 +108,7 @@ def execute(
     ) -> torch.Tensor:
         # collect task vectors
         tvs, base = get_task_vectors(
-            self.out_tensor_name,
+            self.weight_info,
            self.base_model,
             tensors,
             tensor_parameters=self.tensor_parameters.data,
@@ -166,22 +166,24 @@ def group_label(self) -> Optional[str]:
 
 
 def get_task_vectors(
-    parameter_name: str,
+    weight_info: WeightInfo,
     base_model: ModelReference,
     tensors: ImmutableMap[ModelReference, torch.Tensor],
     tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
 ) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
     keys = list(tensors.keys())
     base = tensors[base_model]
 
+    parameter_name = weight_info.name
+
     res = []
     for model in keys:
         if model == base_model:
             continue
 
         x = tensors[model].to(base.dtype)
         if x.shape != base.shape:
-            if "lm_head" in parameter_name or "embed_tokens" in parameter_name:
+            if weight_info.is_embed:
                 x = x[: base.shape[0], : base.shape[1]]
                 logging.warning(f"Using submatrix of {model}:{parameter_name}")
             else:
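This is the substantive fix in this file: the old name heuristic (lm_head / embed_tokens) never matched BERT's embeddings.word_embeddings.weight, so a vocabulary-size mismatch between BERT variants was a hard error. Keying on weight_info.is_embed covers any architecture whose definition flags its embedding weights. The truncation itself, as a standalone sketch with illustrative shapes:

import torch

# Two checkpoints of the same architecture with different vocab sizes,
# e.g. one had extra special tokens added before fine-tuning.
base = torch.zeros(30522, 768)
other = torch.randn(30528, 768)

if other.shape != base.shape:
    # Keep the leading submatrix so rows line up token-for-token, mirroring
    # the branch get_task_vectors takes when weight_info.is_embed is set.
    other = other[: base.shape[0], : base.shape[1]]

assert other.shape == base.shape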

mergekit/merge_methods/linear.py

Lines changed: 6 additions & 5 deletions

@@ -18,17 +18,18 @@
 import torch
 
 from mergekit.architecture import WeightInfo
-from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
+from mergekit.common import ImmutableMap, ModelReference
 from mergekit.graph import Task
 from mergekit.io.tasks import GatherTensors
 from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
+from mergekit.merge_methods.rectify_embed import rectify_embed_sizes
 
 
 class LinearMergeTask(Task[torch.Tensor]):
     gather_tensors: GatherTensors
     tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
     normalize: bool
-    parameter_name: str
+    weight_info: WeightInfo
 
     def uses_accelerator(self) -> bool:
         return True
@@ -44,12 +45,12 @@ def execute(
         tensors = [tensors[key] for key in keys]
         weights = [self.tensor_parameters[key]["weight"] for key in keys]
 
-        rectify_embed_sizes(self.parameter_name, tensors)
+        rectify_embed_sizes(self.weight_info, tensors)
 
         unique_shapes = set(t.shape for t in tensors)
         if len(unique_shapes) != 1:
             raise RuntimeError(
-                f"Tensor size mismatch for {self.parameter_name}, sizes: {list(unique_shapes)}"
+                f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}"
             )
 
         tensors = torch.stack(tensors, dim=0)
@@ -89,5 +90,5 @@ def make_task(
             gather_tensors=tensors,
             tensor_parameters=tensor_parameters,
             normalize=parameters["normalize"],
-            parameter_name=output_weight.name,
+            weight_info=output_weight,
         )
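For reference, once rectify_embed_sizes has reconciled embedding shapes, the rest of LinearMergeTask.execute reduces to a weighted sum. A self-contained sketch of that arithmetic, assuming the normalize=True path:

import torch

tensors = [torch.ones(4, 4), torch.full((4, 4), 3.0)]
weights = [0.25, 0.75]

stacked = torch.stack(tensors, dim=0)     # shape: (n_models, 4, 4)
w = torch.tensor(weights).view(-1, 1, 1)  # broadcast over tensor dims
merged = (stacked * w).sum(dim=0)
merged = merged / w.sum()                 # normalize=True divides by weight sum
# each element: (0.25 * 1.0 + 0.75 * 3.0) / 1.0 == 2.5
assert torch.allclose(merged, torch.full((4, 4), 2.5))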
