@@ -2,6 +2,8 @@
 
 from typing import Any
 
+import torch
+
 from transformer_lens.conversion_utils.conversion_steps import (
     HookConversionSet,
     RearrangeHookConversion,
@@ -12,9 +14,10 @@
 )
 from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
 from transformer_lens.model_bridge.generalized_components import (
-    AttentionBridge,
     BlockBridge,
     EmbeddingBridge,
+    JointQKVAttentionBridge,
+    LinearBridge,
     MLPBridge,
     NormalizationBridge,
     UnembeddingBridge,
@@ -132,15 +135,83 @@ def __init__(self, cfg: Any) -> None:
 
         self.component_mapping = {
             "embed": EmbeddingBridge(name="gpt_neox.embed_in"),
+            "rotary_emb": EmbeddingBridge(name="gpt_neox.rotary_emb"),
             "blocks": BlockBridge(
                 name="gpt_neox.layers",
                 submodules={
                     "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                     "ln2": NormalizationBridge(name="post_attention_layernorm", config=self.cfg),
-                    "attn": AttentionBridge(name="attention", config=self.cfg),
-                    "mlp": MLPBridge(name="mlp"),
+                    "attn": JointQKVAttentionBridge(
+                        name="attention",
+                        config=self.cfg,
+                        split_qkv_matrix=self.split_qkv_matrix,
+                        submodules={
+                            "qkv": LinearBridge(name="query_key_value"),
+                            "o": LinearBridge(name="dense"),
+                        },
+                    ),
+                    "mlp": MLPBridge(
+                        name="mlp",
+                        submodules={
+                            "in": LinearBridge(name="dense_h_to_4h"),
+                            "out": LinearBridge(name="dense_4h_to_h"),
+                        },
+                    ),
                 },
             ),
             "ln_final": NormalizationBridge(name="gpt_neox.final_layer_norm", config=self.cfg),
             "unembed": UnembeddingBridge(name="embed_out"),
         }
+
+    def split_qkv_matrix(
+        self, original_attention_component: Any
+    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
+        """Split the fused QKV matrix into separate Q, K, and V linear transformations.
+
+        Args:
+            original_attention_component: The original attention layer component.
+
+        Returns:
+            Tuple of nn.Linear modules for the Q, K, and V transformations.
+        """
+
+        # Keep mypy happy
+        assert original_attention_component is not None
+        assert original_attention_component.query_key_value is not None
+
+        qkv_weights = original_attention_component.query_key_value.weight
+
+        # Keep mypy happy
+        assert isinstance(qkv_weights, torch.Tensor)
+
+        # Fused qkv_weights shape: [3 * d_model, d_model] -> transposed to [d_model, 3 * d_model],
+        # then split into three equal parts along dim 1 (Q, K, V blocks, in that order).
+        W_Q, W_K, W_V = torch.tensor_split(qkv_weights.T, 3, dim=1)
+
+        qkv_bias = original_attention_component.query_key_value.bias
+
+        # Keep mypy happy
+        assert isinstance(qkv_bias, torch.Tensor)
+
+        # Fused qkv_bias shape: [3 * n_heads * d_head]; reshape to [3, n_heads * d_head]
+        # so that row 0 is the Q bias, row 1 the K bias, and row 2 the V bias.
+        qkv_bias = qkv_bias.reshape(3, self.cfg.n_heads * self.cfg.d_head)
+        b_Q, b_K, b_V = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :]
+
+        # Create nn.Linear modules. After tensor_split, W_Q, W_K, and W_V have shape
+        # [d_model, d_model] in [in_features, out_features] order, whereas nn.Linear
+        # stores its weight as [out_features, in_features], so each slice is
+        # transposed back when assigned.
+        W_Q_transformation = torch.nn.Linear(W_Q.shape[0], W_Q.shape[1], bias=True)
+        W_Q_transformation.weight = torch.nn.Parameter(W_Q.T)
+        W_Q_transformation.bias = torch.nn.Parameter(b_Q)
+
+        W_K_transformation = torch.nn.Linear(W_K.shape[0], W_K.shape[1], bias=True)
+        W_K_transformation.weight = torch.nn.Parameter(W_K.T)
+        W_K_transformation.bias = torch.nn.Parameter(b_K)
+
+        W_V_transformation = torch.nn.Linear(W_V.shape[0], W_V.shape[1], bias=True)
+        W_V_transformation.weight = torch.nn.Parameter(W_V.T)
+        W_V_transformation.bias = torch.nn.Parameter(b_V)
+
+        return W_Q_transformation, W_K_transformation, W_V_transformation
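
To make the splitting step concrete, here is a minimal, self-contained sketch of the same technique outside the bridge, assuming a fused projection whose output is block-ordered [Q; K; V]. It cuts the weight into thirds exactly as split_qkv_matrix does (using tensor_split on the bias, which is equivalent to the reshape-and-index above under that ordering) and checks that the three resulting nn.Linear modules reproduce the matching slices of the fused output. The names and sizes (d_model, fused, as_linear) are illustrative, not part of this PR.

import torch

d_model = 16  # toy size; illustrative only

# A fused projection whose 3 * d_model outputs are block-ordered [Q; K; V].
fused = torch.nn.Linear(d_model, 3 * d_model, bias=True)

# The same splitting step as split_qkv_matrix: transpose the [3 * d_model, d_model]
# weight to [d_model, 3 * d_model], then cut into thirds along dim 1.
W_Q, W_K, W_V = torch.tensor_split(fused.weight.T, 3, dim=1)
b_Q, b_K, b_V = torch.tensor_split(fused.bias, 3, dim=0)


def as_linear(W: torch.Tensor, b: torch.Tensor) -> torch.nn.Linear:
    """Wrap an [in_features, out_features] weight slice in an nn.Linear."""
    linear = torch.nn.Linear(W.shape[0], W.shape[1], bias=True)
    linear.weight = torch.nn.Parameter(W.T)  # nn.Linear stores [out_features, in_features]
    linear.bias = torch.nn.Parameter(b)
    return linear


q_proj = as_linear(W_Q, b_Q)
k_proj = as_linear(W_K, b_K)
v_proj = as_linear(W_V, b_V)

# The three separate projections reproduce the matching slices of the fused output.
x = torch.randn(2, 5, d_model)
q, k, v = torch.tensor_split(fused(x), 3, dim=-1)
assert torch.allclose(q, q_proj(x))
assert torch.allclose(k, k_proj(x))
assert torch.allclose(v, v_proj(x))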
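One caveat worth flagging: in the raw GPT-NeoX checkpoint layout, the 3 * d_model fused outputs are grouped per head as [q_0 k_0 v_0, q_1 k_1 v_1, ...] rather than as three contiguous blocks, and a plain thirds split of that layout would mix the projections. Whether this matters here depends on which layout the bridge hands to split_qkv_matrix. For comparison, a hypothetical head-aware variant (n_heads and d_head illustrative) would look like:

import torch

n_heads, d_head = 4, 8  # illustrative sizes
d_model = n_heads * d_head

fused = torch.nn.Linear(d_model, 3 * d_model, bias=True)

# Expose the per-head [q k v] grouping on the output axis, pick out each
# projection, and flatten back to [n_heads * d_head, d_model] rows.
w = fused.weight.reshape(n_heads, 3, d_head, d_model)
b = fused.bias.reshape(n_heads, 3, d_head)
W_Q, W_K, W_V = (w[:, i].reshape(n_heads * d_head, d_model) for i in range(3))
b_Q, b_K, b_V = (b[:, i].reshape(n_heads * d_head) for i in range(3))

# Head 0's q/k/v rows are the first three d_head-sized groups of fused rows.
assert torch.equal(W_Q[:d_head], fused.weight[:d_head])
assert torch.equal(W_K[:d_head], fused.weight[d_head : 2 * d_head])
assert torch.equal(W_V[:d_head], fused.weight[2 * d_head : 3 * d_head])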