Skip to content

Commit 680d4e7

Browse files
committed
Fix initialization error
1 parent 9c6fdf1 commit 680d4e7

File tree

4 files changed

+1
-128
lines changed

4 files changed

+1
-128
lines changed

transformer_lens/model_bridge/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
EmbeddingBridge,
2424
NormalizationBridge,
2525
JointQKVAttentionBridge,
26-
JointGateUpMLPBridge,
2726
LinearBridge,
2827
MLPBridge,
2928
MoEBridge,
@@ -50,7 +49,6 @@
5049
"EmbeddingBridge",
5150
"NormalizationBridge",
5251
"JointQKVAttentionBridge",
53-
"JointGateUpMLPBridge",
5452
"LinearBridge",
5553
"MLPBridge",
5654
"MoEBridge",

transformer_lens/model_bridge/generalized_components/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,6 @@
2424
from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
2525
JointQKVAttentionBridge,
2626
)
27-
from transformer_lens.model_bridge.generalized_components.joint_gate_up_mlp import (
28-
JointGateUpMLPBridge,
29-
)
3027
from transformer_lens.model_bridge.generalized_components.unembedding import (
3128
UnembeddingBridge,
3229
)
@@ -37,7 +34,6 @@
3734
"EmbeddingBridge",
3835
"NormalizationBridge",
3936
"JointQKVAttentionBridge",
40-
"JointGateUpMLPBridge",
4137
"LinearBridge",
4238
"MLPBridge",
4339
"MoEBridge",

transformer_lens/model_bridge/generalized_components/joint_gate_up_mlp.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

transformer_lens/model_bridge/supported_architectures/gpt_oss.py

Lines changed: 1 addition & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,11 @@
22

33
from typing import Any
44

5-
import torch
6-
75
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
86
from transformer_lens.model_bridge.generalized_components import (
97
AttentionBridge,
108
BlockBridge,
119
EmbeddingBridge,
12-
JointGateUpMLPBridge,
1310
LinearBridge,
1411
MLPBridge,
1512
NormalizationBridge,
@@ -46,17 +43,8 @@ def __init__(self, cfg: Any) -> None:
4643
name="mlp",
4744
submodules={
4845
"router": LinearBridge(name="router"),
49-
"experts": BlockBridge(
46+
"experts": MLPBridge(
5047
name="experts",
51-
submodules={
52-
"gate_up": JointGateUpMLPBridge(
53-
name="gate_up_proj",
54-
gate_up_config={
55-
"split_gate_up_matrix": self.split_gate_up_matrix
56-
},
57-
),
58-
"down": LinearBridge(name="down_proj"),
59-
},
6048
),
6149
},
6250
),
@@ -65,29 +53,3 @@ def __init__(self, cfg: Any) -> None:
6553
"ln_final": NormalizationBridge(name="model.norm"),
6654
"unembed": UnembeddingBridge(name="lm_head"),
6755
}
68-
69-
def split_gate_up_matrix(
70-
self, original_mlp_component: Any
71-
) -> tuple[torch.nn.Linear, torch.nn.Linear]:
72-
gate_up_weight = original_mlp_component.gate_up_proj
73-
gate_up_bias = original_mlp_component.gate_up_proj_bias
74-
75-
# In GPT-OSS, all the gate projection weights lie at even indices,
76-
# all the up projection weights lie at odd indices
77-
gate_weight = gate_up_weight[..., ::2]
78-
up_weight = gate_up_weight[..., 1::2]
79-
80-
gate_bias = gate_up_bias[..., ::2]
81-
up_bias = gate_up_bias[..., 1::2]
82-
83-
gate_projection = torch.nn.Linear(gate_weight.shape[0], gate_weight.shape[1], bias=True)
84-
85-
gate_projection.weight = torch.nn.Parameter(gate_weight)
86-
gate_projection.bias = torch.nn.Parameter(gate_bias)
87-
88-
up_projection = torch.nn.Linear(up_weight.shape[0], up_weight.shape[1])
89-
90-
up_projection.weight = torch.nn.Parameter(up_weight)
91-
up_projection.bias = torch.nn.Parameter(up_bias)
92-
93-
return gate_projection, up_projection

0 commit comments

Comments
 (0)