
Commit a6bddfa

Add support for GPT-OSS

1 parent 98d649a commit a6bddfa

6 files changed: +172 additions, -2 deletions

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
     GPT2ArchitectureAdapter,
     Gpt2LmHeadCustomArchitectureAdapter,
     GptjArchitectureAdapter,
+    GPTOSSArchitectureAdapter,
     LlamaArchitectureAdapter,
     MingptArchitectureAdapter,
     MistralArchitectureAdapter,
@@ -40,6 +41,7 @@
     "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
     "Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
     "GPT2LMHeadModel": GPT2ArchitectureAdapter,
+    "GptOssForCausalLM": GPTOSSArchitectureAdapter,
     "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
     "GPTJForCausalLM": GptjArchitectureAdapter,
     "LlamaForCausalLM": LlamaArchitectureAdapter,

transformer_lens/model_bridge/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
     EmbeddingBridge,
     NormalizationBridge,
     JointQKVAttentionBridge,
+    JointGateUpMLPBridge,
     LinearBridge,
     MLPBridge,
     MoEBridge,
@@ -49,7 +50,8 @@
     "EmbeddingBridge",
     "NormalizationBridge",
     "JointQKVAttentionBridge",
-    "LinearBridge",
+    "JointGateUpMLPBridge",
+    "LinearBridge",
     "MLPBridge",
     "MoEBridge",
     "UnembeddingBridge",

transformer_lens/model_bridge/generalized_components/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -24,6 +24,9 @@
 from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
     JointQKVAttentionBridge,
 )
+from transformer_lens.model_bridge.generalized_components.joint_gate_up_mlp import (
+    JointGateUpMLPBridge,
+)
 from transformer_lens.model_bridge.generalized_components.unembedding import (
     UnembeddingBridge,
 )
@@ -34,7 +37,8 @@
     "EmbeddingBridge",
     "NormalizationBridge",
     "JointQKVAttentionBridge",
-    "LinearBridge",
+    "JointGateUpMLPBridge",
+    "LinearBridge",
     "MLPBridge",
     "MoEBridge",
     "JointQKVAttentionBridge",
transformer_lens/model_bridge/generalized_components/joint_gate_up_mlp.py (new file)

Lines changed: 77 additions & 0 deletions

"""MLP bridge component.

This module contains the bridge component for MLP layers with a joint (fused) gate and up projection.
"""

from typing import Any, Dict, Optional

import torch

# NOTE: the LinearBridge import path below is assumed from the package's naming
# convention (each bridge lives in its own module); adjust if it differs.
from transformer_lens.model_bridge.generalized_components.linear import LinearBridge
from transformer_lens.model_bridge.generalized_components.mlp import MLPBridge


class JointGateUpMLPBridge(MLPBridge):
    """Bridge component for MLP layers with joint gate and up projections.

    This component wraps an MLP layer whose gate and up projections are fused into a
    single matrix, so that both the activations of the joint projection and of the
    separate gate and up projections are hooked and accessible.
    """

    def __init__(
        self,
        name: str,
        model_config: Optional[Any] = None,
        submodules: Optional[Dict[str, Any]] = None,
        gate_up_config: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the JointGateUpMLP bridge.

        Args:
            name: The name of the component in the model.
            model_config: Optional configuration (unused for MLPBridge).
            submodules: Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj).
            gate_up_config: Gate/up-specific configuration holding the function that splits
                the joint projection into separate gate and up projections.
        """
        super().__init__(name, model_config, submodules=submodules or {})
        self.gate_up_config = gate_up_config or {}
        self.gate = LinearBridge(name="gate", config=model_config)
        self.up = LinearBridge(name="up", config=model_config)

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Set the original MLP component and initialize LinearBridges for the gate and up projections.

        Args:
            original_component: The original MLP component to wrap.
        """
        super().set_original_component(original_component)

        gate_projection, up_projection = self.gate_up_config["split_gate_up_matrix"](
            original_component
        )

        # Initialize the LinearBridges for the separated gate and up projections.
        self.gate.set_original_component(gate_projection)
        self.up.set_original_component(up_projection)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward pass through the JointGateUpMLP bridge.

        Args:
            *args: Positional arguments for the original component.
            **kwargs: Keyword arguments for the original component.

        Returns:
            Output hidden states.
        """
        output = super().forward(*args, **kwargs)

        # Extract the input tensor and run it through the separated gate and up
        # projections so that their hook points capture the split activations.
        input_tensor = (
            args[0] if len(args) > 0 else kwargs.get("input", kwargs.get("hidden_states"))
        )
        if input_tensor is not None:
            # Both projections read the same hidden states; the gate output is not
            # fed into the up projection.
            self.gate(input_tensor)
            self.up(input_tensor)

        return output
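As a plain-PyTorch sketch of the pattern this bridge implements (no transformer_lens classes involved; `FusedGateUpMLP` and the hook below are illustrative, not part of the commit): the wrapped module's output is returned unchanged, while the same input is additionally pushed through a projection sliced out of the fused matrix so that a forward hook can observe its activation.

    import torch
    import torch.nn as nn

    class FusedGateUpMLP(nn.Module):
        """Toy MLP with an interleaved fused gate/up projection (illustrative stand-in)."""

        def __init__(self, d_model: int, d_mlp: int):
            super().__init__()
            self.gate_up_proj = nn.Linear(d_model, 2 * d_mlp)
            self.down_proj = nn.Linear(d_mlp, d_model)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            fused = self.gate_up_proj(x)
            gate, up = fused[..., ::2], fused[..., 1::2]  # de-interleave gate / up
            return self.down_proj(torch.nn.functional.silu(gate) * up)

    d_model, d_mlp = 8, 16
    mlp = FusedGateUpMLP(d_model, d_mlp)

    # Separate gate projection, sliced out of the fused weight so it can be hooked on its own.
    gate_proj = nn.Linear(d_model, d_mlp)
    gate_proj.weight = nn.Parameter(mlp.gate_up_proj.weight[::2, :].clone())
    gate_proj.bias = nn.Parameter(mlp.gate_up_proj.bias[::2].clone())

    captured = {}
    gate_proj.register_forward_hook(lambda mod, inp, out: captured.__setitem__("gate", out))

    x = torch.randn(2, d_model)
    y = mlp(x)       # normal output of the wrapped module, unchanged
    gate_proj(x)     # extra pass only to expose the gate activation via the hook

    torch.testing.assert_close(captured["gate"], mlp.gate_up_proj(x)[..., ::2])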

transformer_lens/model_bridge/supported_architectures/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 from transformer_lens.model_bridge.supported_architectures.gpt2 import (
     GPT2ArchitectureAdapter,
 )
+from transformer_lens.model_bridge.supported_architectures.gpt_oss import GPTOSSArchitectureAdapter
 from transformer_lens.model_bridge.supported_architectures.gpt2_lm_head_custom import (
     Gpt2LmHeadCustomArchitectureAdapter,
 )
transformer_lens/model_bridge/supported_architectures/gpt_oss.py (new file)

Lines changed: 86 additions & 0 deletions

"""GPT-OSS architecture adapter."""

from typing import Any

import torch

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    AttentionBridge,
    BlockBridge,
    EmbeddingBridge,
    JointGateUpMLPBridge,
    LinearBridge,
    MLPBridge,
    NormalizationBridge,
    UnembeddingBridge,
)


class GPTOSSArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for the GPT-OSS model."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the GPT-OSS architecture adapter."""
        super().__init__(cfg)

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": EmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": NormalizationBridge(name="input_layernorm"),
                    "attn": AttentionBridge(
                        name="self_attn",
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    "ln2": NormalizationBridge(name="post_attention_layernorm"),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "router": LinearBridge(name="router"),
                            "experts": BlockBridge(
                                name="experts",
                                submodules={
                                    # Pass the split function so the bridge can expose
                                    # separate gate and up projections.
                                    "gate_up": JointGateUpMLPBridge(
                                        name="gate_up_proj",
                                        gate_up_config={
                                            "split_gate_up_matrix": self.split_gate_up_matrix
                                        },
                                    ),
                                    "down": LinearBridge(name="down_proj"),
                                },
                            ),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="model.norm"),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def split_gate_up_matrix(
        self, original_mlp_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear]:
        """Split the fused gate/up projection of a GPT-OSS expert into two Linear modules."""
        gate_up_weight = original_mlp_component.gate_up_proj
        gate_up_bias = original_mlp_component.gate_up_proj_bias

        # In GPT-OSS, the gate projection weights lie at even indices and the
        # up projection weights lie at odd indices of the fused matrix.
        gate_weight = gate_up_weight[..., ::2]
        up_weight = gate_up_weight[..., 1::2]

        gate_bias = gate_up_bias[..., ::2]
        up_bias = gate_up_bias[..., 1::2]

        gate_projection = torch.nn.Linear(gate_weight.shape[0], gate_weight.shape[1], bias=True)
        gate_projection.weight = torch.nn.Parameter(gate_weight)
        gate_projection.bias = torch.nn.Parameter(gate_bias)

        up_projection = torch.nn.Linear(up_weight.shape[0], up_weight.shape[1], bias=True)
        up_projection.weight = torch.nn.Parameter(up_weight)
        up_projection.bias = torch.nn.Parameter(up_bias)

        return gate_projection, up_projection
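A quick, self-contained check of the interleaving assumption described in the comment above (purely illustrative; it does not load real GPT-OSS weights): columns at even indices belong to the gate projection and columns at odd indices to the up projection, so slicing with `[..., ::2]` and `[..., 1::2]` recovers the two matrices.

    import torch

    num_experts, d_model, d_expert = 2, 4, 3

    # Build a fused gate/up matrix whose columns are explicitly interleaved:
    # [g0, u0, g1, u1, ...] along the last dimension.
    gate = torch.randn(num_experts, d_model, d_expert)
    up = torch.randn(num_experts, d_model, d_expert)
    fused = torch.stack((gate, up), dim=-1).reshape(num_experts, d_model, 2 * d_expert)

    # Even columns recover the gate projection, odd columns recover the up projection.
    torch.testing.assert_close(fused[..., ::2], gate)
    torch.testing.assert_close(fused[..., 1::2], up)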
