
Commit eff0180

Add support for GPT-OSS (#1004)

* Add support for GPT-OSS
* Add conversion method to config in gpt-oss architecture adapter
* Fix missing comma
* Fixed doc string issues
* Fix missing import

Co-authored-by: Bryce Meyer <[email protected]>

1 parent b80775f commit eff0180

6 files changed, +185 -0 lines changed


transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
     GPT2ArchitectureAdapter,
     Gpt2LmHeadCustomArchitectureAdapter,
     GptjArchitectureAdapter,
+    GPTOSSArchitectureAdapter,
     LlamaArchitectureAdapter,
     MingptArchitectureAdapter,
     MistralArchitectureAdapter,
@@ -40,6 +41,7 @@
     "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
     "Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
     "GPT2LMHeadModel": GPT2ArchitectureAdapter,
+    "GptOssForCausalLM": GPTOSSArchitectureAdapter,
     "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
     "GPTJForCausalLM": GptjArchitectureAdapter,
     "LlamaForCausalLM": LlamaArchitectureAdapter,

transformer_lens/model_bridge/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     EmbeddingBridge,
     NormalizationBridge,
     JointQKVAttentionBridge,
+    JointGateUpMLPBridge,
     LinearBridge,
     MLPBridge,
     MoEBridge,
@@ -49,6 +50,7 @@
     "EmbeddingBridge",
     "NormalizationBridge",
     "JointQKVAttentionBridge",
+    "JointGateUpMLPBridge",
     "LinearBridge",
     "MLPBridge",
     "MoEBridge",

transformer_lens/model_bridge/generalized_components/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -24,6 +24,9 @@
 from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
     JointQKVAttentionBridge,
 )
+from transformer_lens.model_bridge.generalized_components.joint_gate_up_mlp import (
+    JointGateUpMLPBridge,
+)
 from transformer_lens.model_bridge.generalized_components.unembedding import (
     UnembeddingBridge,
 )
@@ -34,6 +37,7 @@
     "EmbeddingBridge",
     "NormalizationBridge",
     "JointQKVAttentionBridge",
+    "JointGateUpMLPBridge",
     "LinearBridge",
     "MLPBridge",
     "MoEBridge",

transformer_lens/model_bridge/generalized_components/joint_gate_up_mlp.py

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+"""MLP bridge component.
+
+This module contains the bridge component for MLP layers with joint gating and up-projection.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Dict, Optional
+
+import torch
+
+from transformer_lens.model_bridge.generalized_components.base import (
+    GeneralizedComponent,
+)
+from transformer_lens.model_bridge.generalized_components.linear import LinearBridge
+from transformer_lens.model_bridge.generalized_components.mlp import MLPBridge
+
+
+class JointGateUpMLPBridge(MLPBridge):
+    """Bridge component for MLP layers with joint gating and up-projections.
+
+    This component wraps an MLP layer with fused gate and up projections such that both the
+    activations from the joint projection and the separate gate and up projections are hooked
+    and accessible.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        model_config: Optional[Any] = None,
+        submodules: Optional[Dict[str, GeneralizedComponent]] = {},
+        gate_up_config: Optional[Dict[str, Any]] = None,
+    ):
+        """Initialize the JointGateUpMLP bridge.
+
+        Args:
+            name: The name of the component in the model
+            model_config: Optional configuration (unused for MLPBridge)
+            submodules: Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)
+            gate_up_config: Gate/up-specific configuration holding the function that splits the joint projection in two
+        """
+        super().__init__(name, model_config, submodules=submodules)
+        self.gate_up_config = gate_up_config or {}
+        self.gate = LinearBridge(name="gate", config=model_config)
+        self.up = LinearBridge(name="up", config=model_config)
+
+    def set_original_component(self, original_component: torch.nn.Module) -> None:
+        """Set the original MLP component and initialize LinearBridges for gate and up projections.
+
+        Args:
+            original_component: The original MLP component to wrap
+        """
+        super().set_original_component(original_component)
+
+        gate_projection, up_projection = self.gate_up_config["split_gate_up_matrix"](
+            original_component
+        )
+
+        # Initialize the LinearBridges for the separated gate and up projections
+        self.gate.set_original_component(gate_projection)
+        self.up.set_original_component(up_projection)
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        """Forward pass through the JointGateUpMLP bridge.
+
+        Args:
+            *args: Positional arguments for the original component
+            **kwargs: Keyword arguments for the original component
+
+        Returns:
+            Output hidden states
+        """
+        output = super().forward(*args, **kwargs)
+
+        # Extract the input tensor and run it through the gate and up projections
+        # in order to hook their outputs
+        input_tensor = (
+            args[0] if len(args) > 0 else kwargs.get("input", kwargs.get("hidden_states"))
+        )
+        if input_tensor is not None:
+            gated_output = self.gate(input_tensor)
+            self.up(gated_output)
+
+        return output
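
The forward pass above re-runs the input through the split projections purely so their activations can be observed alongside the fused output. Here is a standalone PyTorch sketch of that idea, using plain register_forward_hook in place of the bridge's hook machinery; the module names and sizes are illustrative, not the bridge API.

    import torch

    torch.manual_seed(0)
    d = 8  # toy width; in GPT-OSS the hidden size and expert width happen to match

    x = torch.randn(2, 3, d)  # (batch, pos, d_model)
    gate = torch.nn.Linear(d, d)
    up = torch.nn.Linear(d, d)

    captured: dict[str, torch.Tensor] = {}


    def grab(name):
        """Return a forward hook that records the module's output under `name`."""

        def hook(module, inputs, output):
            captured[name] = output.detach()

        return hook


    gate.register_forward_hook(grab("gate"))
    up.register_forward_hook(grab("up"))

    # Mirrors JointGateUpMLPBridge.forward: the gate output is fed into the up
    # projection only so that both activations get recorded; the wrapped fused
    # module still produces the actual MLP output.
    gated = gate(x)
    up(gated)

    print(captured["gate"].shape, captured["up"].shape)  # torch.Size([2, 3, 8]) twice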

transformer_lens/model_bridge/supported_architectures/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 from transformer_lens.model_bridge.supported_architectures.gpt2 import (
     GPT2ArchitectureAdapter,
 )
+from transformer_lens.model_bridge.supported_architectures.gpt_oss import GPTOSSArchitectureAdapter
 from transformer_lens.model_bridge.supported_architectures.gpt2_lm_head_custom import (
     Gpt2LmHeadCustomArchitectureAdapter,
 )

transformer_lens/model_bridge/supported_architectures/gpt_oss.py

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+"""GPT-OSS architecture adapter."""
+
+from typing import Any
+
+import torch
+
+from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
+from transformer_lens.model_bridge.generalized_components import (
+    AttentionBridge,
+    BlockBridge,
+    EmbeddingBridge,
+    JointGateUpMLPBridge,
+    LinearBridge,
+    MLPBridge,
+    NormalizationBridge,
+    UnembeddingBridge,
+)
+
+
+class GPTOSSArchitectureAdapter(ArchitectureAdapter):
+    """Architecture adapter for the GPT-OSS model."""
+
+    def __init__(self, cfg: Any) -> None:
+        """Initialize the GPT-OSS architecture adapter."""
+        super().__init__(cfg)
+
+        self.component_mapping = {
+            "embed": EmbeddingBridge(name="model.embed_tokens"),
+            "rotary_emb": EmbeddingBridge(name="model.rotary_emb"),
+            "blocks": BlockBridge(
+                name="model.layers",
+                submodules={
+                    "ln1": NormalizationBridge(name="input_layernorm"),
+                    "attn": AttentionBridge(
+                        name="self_attn",
+                        config=self.cfg,
+                        submodules={
+                            "q": LinearBridge(name="q_proj"),
+                            "k": LinearBridge(name="k_proj"),
+                            "v": LinearBridge(name="v_proj"),
+                            "o": LinearBridge(name="o_proj"),
+                        },
+                    ),
+                    "ln2": NormalizationBridge(name="post_attention_layernorm"),
+                    "mlp": MLPBridge(
+                        name="mlp",
+                        submodules={
+                            "router": LinearBridge(name="router"),
+                            "experts": BlockBridge(
+                                name="experts",
+                                submodules={
+                                    "gate_up": JointGateUpMLPBridge(
+                                        name="gate_up_proj",
+                                        gate_up_config={
+                                            "split_gate_up_matrix": self.split_gate_up_matrix
+                                        },
+                                    ),
+                                    "down": LinearBridge(name="down_proj"),
+                                },
+                            ),
+                        },
+                    ),
+                },
+            ),
+            "ln_final": NormalizationBridge(name="model.norm"),
+            "unembed": UnembeddingBridge(name="lm_head"),
+        }
+
+    def split_gate_up_matrix(
+        self, original_mlp_component: Any
+    ) -> tuple[torch.nn.Linear, torch.nn.Linear]:
+        gate_up_weight = original_mlp_component.gate_up_proj
+        gate_up_bias = original_mlp_component.gate_up_proj_bias
+
+        # In GPT-OSS, all the gate projection weights lie at even indices,
+        # all the up projection weights lie at odd indices
+        gate_weight = gate_up_weight[..., ::2]
+        up_weight = gate_up_weight[..., 1::2]
+
+        gate_bias = gate_up_bias[..., ::2]
+        up_bias = gate_up_bias[..., 1::2]
+
+        gate_projection = torch.nn.Linear(gate_weight.shape[0], gate_weight.shape[1], bias=True)
+
+        gate_projection.weight = torch.nn.Parameter(gate_weight)
+        gate_projection.bias = torch.nn.Parameter(gate_bias)
+
+        up_projection = torch.nn.Linear(up_weight.shape[0], up_weight.shape[1])
+
+        up_projection.weight = torch.nn.Parameter(up_weight)
+        up_projection.bias = torch.nn.Parameter(up_bias)
+
+        return gate_projection, up_projection
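
The even/odd interleaving that split_gate_up_matrix relies on can be sanity-checked on a toy fused matrix. The sizes and marker values below are made up for illustration and are not the real GPT-OSS dimensions.

    import torch

    hidden, expert = 4, 4  # toy sizes only

    # Build a fused matrix whose even columns hold recognisable "gate" values
    # and whose odd columns hold recognisable "up" values.
    gate_cols = torch.arange(hidden * expert, dtype=torch.float32).reshape(hidden, expert)
    up_cols = 100.0 + gate_cols

    fused = torch.empty(hidden, 2 * expert)
    fused[:, ::2] = gate_cols
    fused[:, 1::2] = up_cols

    # The same slicing the adapter uses to undo the interleaving.
    gate_weight = fused[..., ::2]  # even columns -> gate projection
    up_weight = fused[..., 1::2]   # odd columns  -> up projection

    assert torch.equal(gate_weight, gate_cols)
    assert torch.equal(up_weight, up_cols)
    print(gate_weight.shape, up_weight.shape)  # torch.Size([4, 4]) torch.Size([4, 4])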
