Commit ff89e67

Create bridge for every module in pythia (#1060)
1 parent 5d662ce commit ff89e67

File tree

1 file changed: +72 −3 lines

  • transformer_lens/model_bridge/supported_architectures/pythia.py


transformer_lens/model_bridge/supported_architectures/pythia.py

Lines changed: 72 additions & 3 deletions
@@ -2,6 +2,8 @@
 
 from typing import Any
 
+import torch
+
 from transformer_lens.conversion_utils.conversion_steps import (
     HookConversionSet,
     RearrangeHookConversion,
@@ -12,9 +14,10 @@
 )
 from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
 from transformer_lens.model_bridge.generalized_components import (
-    AttentionBridge,
     BlockBridge,
     EmbeddingBridge,
+    JointQKVAttentionBridge,
+    LinearBridge,
     MLPBridge,
     NormalizationBridge,
     UnembeddingBridge,
@@ -132,15 +135,81 @@ def __init__(self, cfg: Any) -> None:
 
         self.component_mapping = {
             "embed": EmbeddingBridge(name="gpt_neox.embed_in"),
+            "rotary_emb": EmbeddingBridge(name="gpt_neox.rotary_emb"),
             "blocks": BlockBridge(
                 name="gpt_neox.layers",
                 submodules={
                     "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                     "ln2": NormalizationBridge(name="post_attention_layernorm", config=self.cfg),
-                    "attn": AttentionBridge(name="attention", config=self.cfg),
-                    "mlp": MLPBridge(name="mlp"),
+                    "attn": JointQKVAttentionBridge(
+                        name="attention",
+                        config=self.cfg,
+                        split_qkv_matrix=self.split_qkv_matrix,
+                        submodules={
+                            "qkv": LinearBridge(name="query_key_value"),
+                            "o": LinearBridge(name="dense"),
+                        },
+                    ),
+                    "mlp": MLPBridge(
+                        name="mlp",
+                        submodules={
+                            "in": LinearBridge(name="dense_h_to_4h"),
+                            "out": LinearBridge(name="dense_4h_to_h"),
+                        },
+                    ),
                 },
             ),
             "ln_final": NormalizationBridge(name="gpt_neox.final_layer_norm", config=self.cfg),
             "unembed": UnembeddingBridge(name="embed_out"),
         }
+
+    def split_qkv_matrix(
+        self, original_attention_component: Any
+    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
+        """Split the fused QKV matrix into separate linear transformations.
+
+        Args:
+            original_attention_component: The original attention layer component.
+
+        Returns:
+            Tuple of nn.Linear modules for the Q, K, and V transformations.
+        """
+        # Keep mypy happy
+        assert original_attention_component is not None
+        assert original_attention_component.query_key_value is not None
+
+        qkv_weights = original_attention_component.query_key_value.weight
+
+        # Keep mypy happy
+        assert isinstance(qkv_weights, torch.Tensor)
+
+        # Original qkv_weights shape [3 * d_model, d_model] is transposed to
+        # [d_model, 3 * d_model], then split into three equal parts along
+        # dimension 1 to get the Q, K, and V weights.
+        W_Q, W_K, W_V = torch.tensor_split(qkv_weights.T, 3, dim=1)
+
+        qkv_bias = original_attention_component.query_key_value.bias
+
+        # Keep mypy happy
+        assert isinstance(qkv_bias, torch.Tensor)
+
+        # Original qkv_bias shape [n_heads * 3 * d_head] is reshaped to
+        # [3, n_heads * d_head] to split it into Q, K, and V biases.
+        qkv_bias = qkv_bias.reshape(3, self.cfg.n_heads * self.cfg.d_head)
+        b_Q, b_K, b_V = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :]
+
+        # Create nn.Linear modules. After tensor_split, W_Q, W_K, and W_V have
+        # shape [d_model, d_model] as [in_features, out_features], but nn.Linear
+        # stores its weight as [out_features, in_features], so transpose back
+        # when assigning.
+        W_Q_transformation = torch.nn.Linear(W_Q.shape[0], W_Q.shape[1], bias=True)
+        W_Q_transformation.weight = torch.nn.Parameter(W_Q.T)
+        W_Q_transformation.bias = torch.nn.Parameter(b_Q)
+
+        W_K_transformation = torch.nn.Linear(W_K.shape[0], W_K.shape[1], bias=True)
+        W_K_transformation.weight = torch.nn.Parameter(W_K.T)
+        W_K_transformation.bias = torch.nn.Parameter(b_K)
+
+        W_V_transformation = torch.nn.Linear(W_V.shape[0], W_V.shape[1], bias=True)
+        W_V_transformation.weight = torch.nn.Parameter(W_V.T)
+        W_V_transformation.bias = torch.nn.Parameter(b_V)
+
+        return W_Q_transformation, W_K_transformation, W_V_transformation
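
As a sanity check on the weight surgery above, here is a minimal standalone sketch of the same split-and-rebuild steps on a toy fused projection. The `d_model` value and the `fused` layer are hypothetical stand-ins for a real `query_key_value` module; this is not part of the commit.

import torch

# Toy width; a real Pythia model gets this from its config
# (where n_heads * d_head == d_model).
d_model = 8
fused = torch.nn.Linear(d_model, 3 * d_model, bias=True)  # stand-in for query_key_value

# Same steps as split_qkv_matrix: transpose the [3 * d_model, d_model] weight
# to [d_model, 3 * d_model], split along dim 1, and split the bias into three rows.
W_Q, W_K, W_V = torch.tensor_split(fused.weight.detach().T, 3, dim=1)
b_Q, b_K, b_V = fused.bias.detach().reshape(3, d_model)

# Rebuild a standalone Q projection; nn.Linear stores weights as
# [out_features, in_features], hence the transpose back.
q_proj = torch.nn.Linear(d_model, d_model, bias=True)
q_proj.weight = torch.nn.Parameter(W_Q.T)
q_proj.bias = torch.nn.Parameter(b_Q)

# The rebuilt module reproduces the first third of the fused output.
x = torch.randn(2, d_model)
assert torch.allclose(q_proj(x), fused(x)[:, :d_model], atol=1e-6)

Because tensor_split returns views, the rebuilt parameters share storage with the fused weight, which keeps the operation cheap; clone the chunks first if independent copies are needed.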

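A quick way to convince yourself that the names used in component_mapping are real is to resolve them against the Hugging Face GPT-NeoX module tree. The sketch below is an assumption-laden check, not part of the commit: the checkpoint id is only an example, and `gpt_neox.rotary_emb` is omitted because where rotary embeddings live depends on the transformers version.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
module_names = dict(model.named_modules())

# Each name from component_mapping should resolve to a real submodule
# (layer index 0 substituted for the BlockBridge container).
for path in [
    "gpt_neox.embed_in",
    "gpt_neox.layers.0.input_layernorm",
    "gpt_neox.layers.0.post_attention_layernorm",
    "gpt_neox.layers.0.attention.query_key_value",
    "gpt_neox.layers.0.attention.dense",
    "gpt_neox.layers.0.mlp.dense_h_to_4h",
    "gpt_neox.layers.0.mlp.dense_4h_to_h",
    "gpt_neox.final_layer_norm",
    "embed_out",
]:
    assert path in module_names, f"missing module: {path}"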