Commit 41e6d4a

Supported Architectures – code artifact cleanup (#1136)
* Removed all attributes that directly mapped keys; these are now handled by the component mapping Bridge classes
* Formatting update
* Removed an additional missed key
1 parent 36f1c0a commit 41e6d4a

21 files changed: +3 −275 lines changed

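For context, the entries removed in this commit are the plain string mappings inside each architecture's weight_processing_conversions dict, while entries that also need a tensor rearrangement stay in place. The fragment below is an illustrative sketch only, adapted from the llama.py hunk further down rather than copied verbatim from the repository; import paths are omitted, and n_heads stands in for self.cfg.n_heads from the surrounding __init__.

# Illustrative sketch only -- mirrors the pattern shown in the hunks below, not
# verbatim repository code. ParamProcessingConversion and RearrangeTensorConversion
# are the conversion helpers referenced by the surviving entries; their import
# paths are omitted here, and n_heads stands in for self.cfg.n_heads.
weight_processing_conversions = {
    # Direct key mapping: TransformerLens parameter name -> HuggingFace
    # state-dict key, with no tensor transformation. Entries of this kind are
    # what the commit deletes, since the component mapping Bridge classes now
    # resolve them.
    "blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight",
    # Conversion entry: the mapping also rearranges the tensor during weight
    # processing/folding, so it keeps an explicit ParamProcessingConversion and
    # remains after this cleanup.
    "blocks.{i}.attn.q": ParamProcessingConversion(
        tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_heads),
        source_key="model.layers.{i}.self_attn.q_proj.weight",
    ),
}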

transformer_lens/model_bridge/supported_architectures/bert.py

Lines changed: 0 additions & 22 deletions
@@ -41,15 +41,6 @@ def __init__(self, cfg: Any) -> None:
         self.cfg.attn_only = False
 
         self.weight_processing_conversions = {
-            "embed.e": "bert.embeddings.word_embeddings.weight",
-            "pos_embed.pos": "bert.embeddings.position_embeddings.weight",
-            "embed.token_type_embeddings": "bert.embeddings.token_type_embeddings.weight",
-            "embed.LayerNorm.weight": "bert.embeddings.LayerNorm.weight",
-            "embed.LayerNorm.bias": "bert.embeddings.LayerNorm.bias",
-            "blocks.{i}.ln1.w": "bert.encoder.layer.{i}.attention.output.LayerNorm.weight",
-            "blocks.{i}.ln1.b": "bert.encoder.layer.{i}.attention.output.LayerNorm.bias",
-            "blocks.{i}.ln2.w": "bert.encoder.layer.{i}.output.LayerNorm.weight",
-            "blocks.{i}.ln2.b": "bert.encoder.layer.{i}.output.LayerNorm.bias",
             "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(h d_head) d_model -> h d_head d_model"
@@ -86,19 +77,6 @@ def __init__(self, cfg: Any) -> None:
                 ),
                 source_key="bert.encoder.layer.{i}.attention.output.dense.weight",
             ),
-            "blocks.{i}.attn.o.bias": "bert.encoder.layer.{i}.attention.output.dense.bias",
-            "blocks.{i}.mlp.in": "bert.encoder.layer.{i}.intermediate.dense.weight",
-            "blocks.{i}.mlp.b_in": "bert.encoder.layer.{i}.intermediate.dense.bias",
-            "blocks.{i}.mlp.out": "bert.encoder.layer.{i}.output.dense.weight",
-            "blocks.{i}.mlp.b_out": "bert.encoder.layer.{i}.output.dense.bias",
-            "ln_final.w": "bert.pooler.dense.weight",
-            "ln_final.b": "bert.pooler.dense.bias",
-            "unembed.u": "cls.predictions.transform.dense.weight",
-            "unembed.b_U": "cls.predictions.transform.dense.bias",
-            "unembed.LayerNorm.weight": "cls.predictions.transform.LayerNorm.weight",
-            "unembed.LayerNorm.bias": "cls.predictions.transform.LayerNorm.bias",
-            "unembed.decoder.weight": "cls.predictions.decoder.weight",
-            "unembed.decoder.bias": "cls.predictions.bias",
         }
 
         # Set up component mapping

transformer_lens/model_bridge/supported_architectures/bloom.py

Lines changed: 0 additions & 16 deletions
@@ -36,9 +36,6 @@ def __init__(self, cfg: Any) -> None:
 
         self.cfg.default_prepend_bos = False
         self.weight_processing_conversions = {
-            "embed.e": "transformer.word_embeddings.weight",
-            "blocks.{i}.ln1.w": "transformer.h.{i}.input_layernorm.weight",
-            "blocks.{i}.ln1.b": "transformer.h.{i}.input_layernorm.bias",
             "blocks.{i}.attn.q": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "(three n h) m -> three n m h",
@@ -67,19 +64,6 @@ def __init__(self, cfg: Any) -> None:
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                 source_key="transformer.h.{i}.self_attention.dense.weight",
             ),
-            "blocks.{i}.attn.b_Q": "transformer.h.{i}.self_attention.query_key_value.bias",
-            "blocks.{i}.attn.b_K": "transformer.h.{i}.self_attention.query_key_value.bias",
-            "blocks.{i}.attn.b_V": "transformer.h.{i}.self_attention.query_key_value.bias",
-            "blocks.{i}.attn.b_O": "transformer.h.{i}.self_attention.dense.bias",
-            "blocks.{i}.ln2.w": "transformer.h.{i}.post_attention_layernorm.weight",
-            "blocks.{i}.ln2.b": "transformer.h.{i}.post_attention_layernorm.bias",
-            "blocks.{i}.mlp.in": "transformer.h.{i}.mlp.dense_h_to_4h.weight",
-            "blocks.{i}.mlp.b_in": "transformer.h.{i}.mlp.dense_h_to_4h.bias",
-            "blocks.{i}.mlp.out": "transformer.h.{i}.mlp.dense_4h_to_h.weight",
-            "blocks.{i}.mlp.b_out": "transformer.h.{i}.mlp.dense_4h_to_h.bias",
-            "ln_final.w": "transformer.ln_f.weight",
-            "ln_final.b": "transformer.ln_f.bias",
-            "unembed.u": "lm_head.weight",
         }
 
         self.component_mapping = {

transformer_lens/model_bridge/supported_architectures/gemma1.py

Lines changed: 0 additions & 7 deletions
@@ -46,8 +46,6 @@ def __init__(self, cfg: Any) -> None:
                 ),
                 source_key="model.embed_tokens.weight",
             ),
-            "blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight",
-            "blocks.{i}.ln2.w": "model.layers.{i}.post_attention_layernorm.weight",
             "blocks.{i}.attn.q": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.q_proj.weight",
@@ -64,11 +62,6 @@ def __init__(self, cfg: Any) -> None:
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
-            "blocks.{i}.mlp.in": "model.layers.{i}.mlp.up_proj.weight.T",
-            "blocks.{i}.mlp.gate": "model.layers.{i}.mlp.gate_proj.weight.T",
-            "blocks.{i}.mlp.out": "model.layers.{i}.mlp.down_proj.weight.T",
-            "ln_final.w": "model.norm.weight",
-            "unembed.u": "lm_head.weight.T",  # Not shared with embedding
         }
 
         self.component_mapping = {

transformer_lens/model_bridge/supported_architectures/gemma2.py

Lines changed: 0 additions & 7 deletions
@@ -49,8 +49,6 @@ def __init__(self, cfg: Any) -> None:
                 ),
                 source_key="model.embed_tokens.weight",
             ),
-            "blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight",
-            "blocks.{i}.ln2.w": "model.layers.{i}.post_attention_layernorm.weight",
             "blocks.{i}.attn.q": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.q_proj.weight",
@@ -73,11 +71,6 @@ def __init__(self, cfg: Any) -> None:
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
-            "blocks.{i}.mlp.in": "model.layers.{i}.mlp.up_proj.weight.T",
-            "blocks.{i}.mlp.gate": "model.layers.{i}.mlp.gate_proj.weight.T",
-            "blocks.{i}.mlp.out": "model.layers.{i}.mlp.down_proj.weight.T",
-            "ln_final.w": "model.norm.weight",
-            "unembed.u": "lm_head.weight.T",  # Not shared with embedding
         }
 
         self.component_mapping = {

transformer_lens/model_bridge/supported_architectures/gpt2_lm_head_custom.py

Lines changed: 0 additions & 14 deletions
@@ -26,12 +26,6 @@ def __init__(self, cfg: Any) -> None:
         super().__init__(cfg)
 
         self.weight_processing_conversions = {
-            "pos_embed.pos": "transformer.wpe.weight",
-            "embed.e": "transformer.wte.weight",
-            "blocks.{i}.ln1.w": "transformer.h.{i}.ln_1.weight",
-            "blocks.{i}.ln1.b": "transformer.h.{i}.ln_1.bias",
-            "blocks.{i}.ln2.w": "transformer.h.{i}.ln_2.weight",
-            "blocks.{i}.ln2.b": "transformer.h.{i}.ln_2.bias",
             "blocks.{i}.attn.q": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "d_model (n d_head) -> n d_model d_head"
@@ -68,14 +62,6 @@ def __init__(self, cfg: Any) -> None:
                 ),
                 source_key="transformer.h.{i}.attn.c_proj.weight",
             ),
-            "blocks.{i}.attn.b_O": "transformer.h.{i}.attn.c_proj.bias",
-            "blocks.{i}.mlp.in": "transformer.h.{i}.mlp.c_fc.weight",
-            "blocks.{i}.mlp.b_in": "transformer.h.{i}.mlp.c_fc.bias",
-            "blocks.{i}.mlp.out": "transformer.h.{i}.mlp.c_proj.weight",
-            "blocks.{i}.mlp.b_out": "transformer.h.{i}.mlp.c_proj.bias",
-            "ln_final.w": "transformer.ln_f.weight",
-            "ln_final.b": "transformer.ln_f.bias",
-            "unembed.u": "lm_head.weight",
             # "unembed.b_U": "lm_head.bias",  # gpt2 has no unembed bias
         }

transformer_lens/model_bridge/supported_architectures/gpt_oss.py

Lines changed: 0 additions & 8 deletions
@@ -39,9 +39,6 @@ def __init__(self, cfg: Any) -> None:
         # Conversion rules for weight processing/folding
         # GPT-OSS uses MoE with batched experts, so we need special handling
         self.weight_processing_conversions = {
-            "embed.e": "model.embed_tokens.weight",
-            "blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight",
-            "blocks.{i}.ln2.w": "model.layers.{i}.post_attention_layernorm.weight",
             "blocks.{i}.attn.q": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.q_proj.weight",
@@ -58,11 +55,6 @@ def __init__(self, cfg: Any) -> None:
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
-            # Note: MLP weights for MoE models with batched experts are not directly mappable
-            # The experts use batched tensors [num_experts, ...] which need special handling
-            # These mappings are for the router only
-            "ln_final.w": "model.norm.weight",
-            "unembed.u": "lm_head.weight.T",
         }
 
         self.component_mapping = {

transformer_lens/model_bridge/supported_architectures/gptj.py

Lines changed: 0 additions & 11 deletions
@@ -33,9 +33,6 @@ def __init__(self, cfg: Any) -> None:
         self.cfg.attn_only = False
 
         self.weight_processing_conversions = {
-            "embed.e": "transformer.wte.weight",
-            "blocks.{i}.ln1.w": "transformer.h.{i}.ln_1.weight",
-            "blocks.{i}.ln1.b": "transformer.h.{i}.ln_1.bias",
             "blocks.{i}.attn.q": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
                 source_key="transformer.h.{i}.attn.q_proj.weight",
@@ -52,14 +49,6 @@ def __init__(self, cfg: Any) -> None:
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                 source_key="transformer.h.{i}.attn.out_proj.weight",
             ),
-            "blocks.{i}.mlp.in": "transformer.h.{i}.mlp.fc_in.weight",
-            "blocks.{i}.mlp.b_in": "transformer.h.{i}.mlp.fc_in.bias",
-            "blocks.{i}.mlp.out": "transformer.h.{i}.mlp.fc_out.weight",
-            "blocks.{i}.mlp.b_out": "transformer.h.{i}.mlp.fc_out.bias",
-            "ln_final.w": "transformer.ln_f.weight",
-            "ln_final.b": "transformer.ln_f.bias",
-            "unembed.u": "lm_head.weight",
-            "unembed.b_U": "lm_head.bias",
         }
 
         self.component_mapping = {

transformer_lens/model_bridge/supported_architectures/llama.py

Lines changed: 0 additions & 8 deletions
@@ -71,9 +71,6 @@ def __init__(self, cfg: Any) -> None:
         self.cfg.eps_attr = "variance_epsilon"
 
         self.weight_processing_conversions = {
-            "embed.e": "model.embed_tokens.weight",
-            "blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight",
-            "blocks.{i}.ln2.w": "model.layers.{i}.post_attention_layernorm.weight",
             "blocks.{i}.attn.q": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.q_proj.weight",
@@ -96,11 +93,6 @@ def __init__(self, cfg: Any) -> None:
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
-            "blocks.{i}.mlp.in": "model.layers.{i}.mlp.up_proj.weight.T",
-            "blocks.{i}.mlp.gate": "model.layers.{i}.mlp.gate_proj.weight.T",
-            "blocks.{i}.mlp.out": "model.layers.{i}.mlp.down_proj.weight.T",
-            "ln_final.w": "model.norm.weight",
-            "unembed.u": "lm_head.weight.T",  # Not shared with embedding
         }
 
         self.component_mapping = {

transformer_lens/model_bridge/supported_architectures/mingpt.py

Lines changed: 0 additions & 15 deletions
@@ -31,12 +31,6 @@ def __init__(self, cfg: Any) -> None:
         super().__init__(cfg)
 
         self.weight_processing_conversions = {
-            "pos_embed.pos": "transformer.wpe.weight",
-            "embed.e": "transformer.wte.weight",
-            "blocks.{i}.ln1.w": "transformer.h.{i}.ln_1.weight",
-            "blocks.{i}.ln1.b": "transformer.h.{i}.ln_1.bias",
-            "blocks.{i}.ln2.w": "transformer.h.{i}.ln_2.weight",
-            "blocks.{i}.ln2.b": "transformer.h.{i}.ln_2.bias",
             "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion(
                     "d_model (3 n_head d_head) -> 3 n_head d_head d_model"
@@ -73,15 +67,6 @@ def __init__(self, cfg: Any) -> None:
                 ),
                 source_key="transformer.h.{i}.attn.c_proj.weight",
             ),
-            "blocks.{i}.attn.o.bias": "transformer.h.{i}.attn.c_proj.bias",
-            "blocks.{i}.mlp.in": "transformer.h.{i}.mlp.c_fc.weight",
-            "blocks.{i}.mlp.b_in": "transformer.h.{i}.mlp.c_fc.bias",
-            "blocks.{i}.mlp.out": "transformer.h.{i}.mlp.c_proj.weight",
-            "blocks.{i}.mlp.b_out": "transformer.h.{i}.mlp.c_proj.bias",
-            "unembed.u": "lm_head.weight",
-            "unembed.b_U": "lm_head.bias",
-            "ln_final.w": "transformer.ln_f.weight",
-            "ln_final.b": "transformer.ln_f.bias",
         }
 
         # Set up component mapping

transformer_lens/model_bridge/supported_architectures/mistral.py

Lines changed: 0 additions & 8 deletions
@@ -45,9 +45,6 @@ def __init__(self, cfg: Any) -> None:
         self.cfg.uses_rms_norm = True
 
         self.weight_processing_conversions = {
-            "embed.e": "model.embed_tokens.weight",
-            "blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight",
-            "blocks.{i}.ln2.w": "model.layers.{i}.post_attention_layernorm.weight",
             "blocks.{i}.attn.q": ParamProcessingConversion(
                 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.q_proj.weight",
@@ -68,11 +65,6 @@ def __init__(self, cfg: Any) -> None:
                 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                 source_key="model.layers.{i}.self_attn.o_proj.weight",
             ),
-            "blocks.{i}.mlp.in": "model.layers.{i}.mlp.up_proj.weight.T",
-            "blocks.{i}.mlp.gate": "model.layers.{i}.mlp.gate_proj.weight.T",
-            "blocks.{i}.mlp.out": "model.layers.{i}.mlp.down_proj.weight.T",
-            "ln_final.w": "model.norm.weight",
-            "unembed.u": "lm_head.weight.T",  # Not shared with embedding
        }
 
         self.component_mapping = {
