66
77from transformer_lens .model_bridge .architecture_adapter import ArchitectureAdapter
88from transformer_lens .model_bridge .generalized_components import (
9+ AttentionBridge ,
910 BlockBridge ,
1011 EmbeddingBridge ,
1112 JointGateUpMLPBridge ,
@@ -32,6 +33,7 @@ def __init__(self, cfg: Any) -> None:
3233 "ln1" : NormalizationBridge (name = "input_layernorm" ),
3334 "attn" : AttentionBridge (
3435 name = "self_attn" ,
36+ config = self .cfg ,
3537 submodules = {
3638 "q" : LinearBridge (name = "q_proj" ),
3739 "k" : LinearBridge (name = "k_proj" ),
@@ -64,28 +66,28 @@ def __init__(self, cfg: Any) -> None:
6466 "unembed" : UnembeddingBridge (name = "lm_head" ),
6567 }
6668
67- def split_gate_up_matrix (
68- self , original_mlp_component : Any
69- ) -> tuple [torch .nn .Linear , torch .nn .Linear ]:
70- gate_up_weight = original_mlp_component .gate_up_proj
71- gate_up_bias = original_mlp_component .gate_up_proj_bias
69+ def split_gate_up_matrix (
70+ self , original_mlp_component : Any
71+ ) -> tuple [torch .nn .Linear , torch .nn .Linear ]:
72+ gate_up_weight = original_mlp_component .gate_up_proj
73+ gate_up_bias = original_mlp_component .gate_up_proj_bias
7274
73- # In GPT-OSS, all the gate projection weights lie at even indices,
74- # all the up projection weights lie at odd indices
75- gate_weight = gate_up_weight [..., ::2 ]
76- up_weight = gate_up_weight [..., 1 ::2 ]
75+ # In GPT-OSS, all the gate projection weights lie at even indices,
76+ # all the up projection weights lie at odd indices
77+ gate_weight = gate_up_weight [..., ::2 ]
78+ up_weight = gate_up_weight [..., 1 ::2 ]
7779
78- gate_bias = gate_up_bias [..., ::2 ]
79- up_bias = gate_up_bias [..., 1 ::2 ]
80+ gate_bias = gate_up_bias [..., ::2 ]
81+ up_bias = gate_up_bias [..., 1 ::2 ]
8082
81- gate_projection = torch .nn .Linear (gate_weight .shape [0 ], gate_weight .shape [1 ], bias = True )
83+ gate_projection = torch .nn .Linear (gate_weight .shape [0 ], gate_weight .shape [1 ], bias = True )
8284
83- gate_projection .weight = torch .nn .Parameter (gate_weight )
84- gate_projection .bias = torch .nn .Parameter (bias )
85+ gate_projection .weight = torch .nn .Parameter (gate_weight )
86+ gate_projection .bias = torch .nn .Parameter (gate_bias )
8587
86- up_projection = torch .nn .Linear (up_weight .shape [0 ], up_weight .shape [1 ])
88+ up_projection = torch .nn .Linear (up_weight .shape [0 ], up_weight .shape [1 ])
8789
88- up_projection .weight = torch .nn .Parameter (up_weight )
89- up_projection .bias = torch .nn .Parameter (up_bias )
90+ up_projection .weight = torch .nn .Parameter (up_weight )
91+ up_projection .bias = torch .nn .Parameter (up_bias )
9092
91- return gate_projection , up_projection
93+ return gate_projection , up_projection
0 commit comments