Skip to content

Commit 9c6fdf1

Browse files
committed
fix missing import
1 parent c2e2d90 commit 9c6fdf1

File tree

1 file changed

+20
-18
lines changed
  • transformer_lens/model_bridge/supported_architectures

1 file changed

+20
-18
lines changed

transformer_lens/model_bridge/supported_architectures/gpt_oss.py

Lines changed: 20 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
88
from transformer_lens.model_bridge.generalized_components import (
9+
AttentionBridge,
910
BlockBridge,
1011
EmbeddingBridge,
1112
JointGateUpMLPBridge,
@@ -32,6 +33,7 @@ def __init__(self, cfg: Any) -> None:
3233
"ln1": NormalizationBridge(name="input_layernorm"),
3334
"attn": AttentionBridge(
3435
name="self_attn",
36+
config=self.cfg,
3537
submodules={
3638
"q": LinearBridge(name="q_proj"),
3739
"k": LinearBridge(name="k_proj"),
@@ -64,28 +66,28 @@ def __init__(self, cfg: Any) -> None:
6466
"unembed": UnembeddingBridge(name="lm_head"),
6567
}
6668

67-
def split_gate_up_matrix(
68-
self, original_mlp_component: Any
69-
) -> tuple[torch.nn.Linear, torch.nn.Linear]:
70-
gate_up_weight = original_mlp_component.gate_up_proj
71-
gate_up_bias = original_mlp_component.gate_up_proj_bias
69+
def split_gate_up_matrix(
70+
self, original_mlp_component: Any
71+
) -> tuple[torch.nn.Linear, torch.nn.Linear]:
72+
gate_up_weight = original_mlp_component.gate_up_proj
73+
gate_up_bias = original_mlp_component.gate_up_proj_bias
7274

73-
# In GPT-OSS, all the gate projection weights lie at even indices,
74-
# all the up projection weights lie at odd indices
75-
gate_weight = gate_up_weight[..., ::2]
76-
up_weight = gate_up_weight[..., 1::2]
75+
# In GPT-OSS, all the gate projection weights lie at even indices,
76+
# all the up projection weights lie at odd indices
77+
gate_weight = gate_up_weight[..., ::2]
78+
up_weight = gate_up_weight[..., 1::2]
7779

78-
gate_bias = gate_up_bias[..., ::2]
79-
up_bias = gate_up_bias[..., 1::2]
80+
gate_bias = gate_up_bias[..., ::2]
81+
up_bias = gate_up_bias[..., 1::2]
8082

81-
gate_projection = torch.nn.Linear(gate_weight.shape[0], gate_weight.shape[1], bias=True)
83+
gate_projection = torch.nn.Linear(gate_weight.shape[0], gate_weight.shape[1], bias=True)
8284

83-
gate_projection.weight = torch.nn.Parameter(gate_weight)
84-
gate_projection.bias = torch.nn.Parameter(bias)
85+
gate_projection.weight = torch.nn.Parameter(gate_weight)
86+
gate_projection.bias = torch.nn.Parameter(gate_bias)
8587

86-
up_projection = torch.nn.Linear(up_weight.shape[0], up_weight.shape[1])
88+
up_projection = torch.nn.Linear(up_weight.shape[0], up_weight.shape[1])
8789

88-
up_projection.weight = torch.nn.Parameter(up_weight)
89-
up_projection.bias = torch.nn.Parameter(up_bias)
90+
up_projection.weight = torch.nn.Parameter(up_weight)
91+
up_projection.bias = torch.nn.Parameter(up_bias)
9092

91-
return gate_projection, up_projection
93+
return gate_projection, up_projection

0 commit comments

Comments (0)