
Commit 4bed46d

Attn pattern shape (#1029)
* Move QKV separation to bridge that directly wraps QKV matrix
* Fix typing issues
* Fix hook collection issues
* Ensuring standardized hook shape
* Fix syntax error
* Run CI again
* adjust test to reflect new hook names in qkv bridge
* simplify getattr in base component
* Add parameter for conversion rule of hook_in and hook_out in qkvbridge
* moved hook point wrapper, and added more test coverage
* matched attn pattern shape to hooked transformer
* revised hook pattern application
* updated outdated test

---------

Co-authored-by: degenfabian <[email protected]>
Co-authored-by: Fabian Degen <[email protected]>
1 parent 2c24781 commit 4bed46d

File tree

3 files changed (+235, -28 lines)


tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 5 additions & 6 deletions
@@ -249,7 +249,7 @@ def test_attention_pattern_hook_shape_custom_conversion():
 
 
 def test_attention_pattern_hook_shape():
-    """Test that the attention pattern hook produces the correct shape (batch, n_heads, pos, pos)."""
+    """Test that the attention pattern hook produces the correct shape (n_heads, pos, pos)."""
     model_name = "gpt2"  # Use a smaller model for testing
     bridge = TransformerBridge.boot_transformers(
         model_name,
@@ -289,15 +289,14 @@ def capture_pattern_hook(tensor, hook):
     # Get the captured pattern tensor
     pattern_tensor = list(captured_patterns.values())[0]
 
-    # Verify the shape is (batch, n_heads, pos, pos)
+    # Verify the shape is (n_heads, pos, pos) - attention patterns should not have batch dimension
     assert (
-        len(pattern_tensor.shape) == 4
-    ), f"Pattern tensor should be 4D, got {len(pattern_tensor.shape)}D"
+        len(pattern_tensor.shape) == 3
+    ), f"Pattern tensor should be 3D, got {len(pattern_tensor.shape)}D"
 
-    batch_dim, n_heads_dim, pos_q_dim, pos_k_dim = pattern_tensor.shape
+    n_heads_dim, pos_q_dim, pos_k_dim = pattern_tensor.shape
 
     # Verify dimensions make sense
-    assert batch_dim == batch_size, f"Batch dimension should be {batch_size}, got {batch_dim}"
     assert (
         n_heads_dim == bridge.cfg.n_heads
     ), f"Heads dimension should be {bridge.cfg.n_heads}, got {n_heads_dim}"

transformer_lens/model_bridge/bridge.py

Lines changed: 20 additions & 0 deletions
@@ -905,6 +905,7 @@ def run_with_hooks(
         return_type: Optional[str] = "logits",
         names_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
         stop_at_layer: Optional[int] = None,
+        remove_batch_dim: bool = False,
         **kwargs,
     ) -> Any:
         """Run the model with specified forward and backward hooks.
@@ -918,6 +919,7 @@
             return_type: What to return ("logits", "loss", etc.)
             names_filter: Filter for hook names (not used directly, for compatibility)
             stop_at_layer: Layer to stop at (not yet fully implemented)
+            remove_batch_dim: Whether to remove batch dimension from hook inputs (only works for batch_size==1)
             **kwargs: Additional arguments
 
         Returns:
@@ -958,6 +960,24 @@ def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool
             aliases = collect_aliases_recursive(self)
 
             for hook_name_or_filter, hook_fn in hooks:
+                # Wrap the hook function to handle remove_batch_dim if needed
+                if remove_batch_dim:
+                    original_hook_fn = hook_fn
+
+                    def wrapped_hook_fn(tensor, hook):
+                        # Remove batch dimension if it's size 1
+                        if tensor.shape[0] == 1:
+                            tensor_no_batch = tensor.squeeze(0)
+                            result = original_hook_fn(tensor_no_batch, hook)
+                            # Add batch dimension back if result doesn't have it
+                            if result.dim() == tensor_no_batch.dim():
+                                result = result.unsqueeze(0)
+                            return result
+                        else:
+                            return original_hook_fn(tensor, hook)
+
+                    hook_fn = wrapped_hook_fn
+
                 if isinstance(hook_name_or_filter, str):
                     # Direct hook name - check for aliases first
                     hook_dict = self.hook_dict
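
Note: a short usage sketch of the new flag; only `remove_batch_dim` itself comes from this diff, while the import path, `to_tokens`, the `fwd_hooks` keyword, and the `blocks.0.hook_resid_pre` hook name are assumptions. With `remove_batch_dim=True` and a batch of exactly one, the wrapper squeezes the leading axis before calling the user hook and restores it afterwards; for larger batches the hook sees the tensor unchanged:

```python
from transformer_lens.model_bridge import TransformerBridge  # assumed import path

bridge = TransformerBridge.boot_transformers("gpt2")
tokens = bridge.to_tokens("Hello world")  # shape (1, pos); the wrapper only strips an axis of size 1

def inspect_resid(tensor, hook):
    # With remove_batch_dim=True and batch_size == 1, the tensor arrives without
    # the batch axis, e.g. (pos, d_model) instead of (1, pos, d_model).
    print(hook.name, tuple(tensor.shape))
    return tensor

bridge.run_with_hooks(
    tokens,
    fwd_hooks=[("blocks.0.hook_resid_pre", inspect_resid)],  # hypothetical hook name
    remove_batch_dim=True,
)
```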

transformer_lens/model_bridge/generalized_components/attention.py

Lines changed: 210 additions & 22 deletions
@@ -6,36 +6,56 @@
 from typing import Any, Dict, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 
 from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import (
     AttentionAutoConversion,
 )
 from transformer_lens.conversion_utils.conversion_steps.base_hook_conversion import (
     BaseHookConversion,
 )
-from transformer_lens.conversion_utils.conversion_steps.rearrange_hook_conversion import (
-    RearrangeHookConversion,
-)
 from transformer_lens.hook_points import HookPoint
 from transformer_lens.model_bridge.generalized_components.base import (
     GeneralizedComponent,
 )
 
 
+class AttentionPatternConversion(BaseHookConversion):
+    """Custom conversion rule for attention patterns that always removes batch dimension."""
+
+    def handle_conversion(self, tensor: torch.Tensor, *args) -> torch.Tensor:
+        """Convert attention pattern tensor to standard shape [n_heads, pos, pos].
+
+        Args:
+            tensor: Input tensor with shape [batch, n_heads, pos, pos] or [n_heads, pos, pos]
+            *args: Additional context arguments (ignored)
+
+        Returns:
+            Tensor with shape [n_heads, pos, pos]
+        """
+        if tensor.dim() == 4:
+            # Remove batch dimension if present
+            return tensor.squeeze(0)
+        elif tensor.dim() == 3:
+            # Already in correct shape
+            return tensor
+        else:
+            raise ValueError(f"Unexpected tensor shape for attention pattern: {tensor.shape}")
+
+
 class AttentionBridge(GeneralizedComponent):
     """Bridge component for attention layers.
 
-    This component wraps attention layers from different architectures and provides
-    a standardized interface for hook registration and execution.
+    This component handles the conversion between Hugging Face attention layers
+    and TransformerLens attention components.
     """
 
     hook_aliases = {
-        "hook_result": "hook_hidden_states",
-        "hook_attn_scores": "o.hook_in",
+        "hook_result": "hook_out",
         "hook_q": "q.hook_out",
         "hook_k": "k.hook_out",
         "hook_v": "v.hook_out",
-        "hook_z": "o.hook_out",
+        "hook_z": "hook_hidden_states",
     }
 
     property_aliases = {
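
Note: the new conversion rule can be checked in isolation; a minimal sketch, assuming `AttentionPatternConversion` (via `BaseHookConversion`) needs no constructor arguments, with the import path taken from the file above:

```python
import torch

from transformer_lens.model_bridge.generalized_components.attention import (
    AttentionPatternConversion,
)

conversion = AttentionPatternConversion()

batched = torch.rand(1, 12, 5, 5)   # [batch=1, n_heads, pos, pos]
unbatched = torch.rand(12, 5, 5)    # already [n_heads, pos, pos]

assert conversion.handle_conversion(batched).shape == (12, 5, 5)    # batch axis squeezed away
assert conversion.handle_conversion(unbatched).shape == (12, 5, 5)  # passed through unchanged

# squeeze(0) only drops the leading axis when it has size 1, so a 4D input with
# batch > 1 comes back still 4D; tensors that are neither 3D nor 4D raise ValueError.
```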
@@ -65,7 +85,7 @@ def __init__(
             submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
             conversion_rule: Optional conversion rule. If None, AttentionAutoConversion will be used
             pattern_conversion_rule: Optional conversion rule for attention patterns. If None,
-                uses default RearrangeHookConversion to reshape to (batch, n_heads, pos, pos)
+                uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape
         """
         # Set up conversion rule - use AttentionAutoConversion if None
         if conversion_rule is None:
@@ -74,8 +94,11 @@
         super().__init__(
             name, config=config, submodules=submodules or {}, conversion_rule=conversion_rule
         )
-        self.hook_hidden_states = HookPoint()
+
+        # Create only the hook points that are actually used for attention processing
+        self.hook_attn_scores = HookPoint()
         self.hook_pattern = HookPoint()
+        self.hook_hidden_states = HookPoint()
 
         # Apply conversion rule to attention-specific hooks
         self.hook_hidden_states.hook_conversion = conversion_rule
@@ -84,29 +107,196 @@
         if pattern_conversion_rule is not None:
             pattern_conversion = pattern_conversion_rule
         else:
-            # Create default conversion rule for attention patterns - reshape to (batch, n_heads, pos, pos)
-            # This assumes the input is (batch, n_heads, seq_len, seq_len) or similar
-            pattern_conversion = RearrangeHookConversion(
-                "batch n_heads pos_q pos_k -> batch n_heads pos_q pos_k"
-            )
+            # Use custom conversion rule that always removes batch dimension
+            pattern_conversion = AttentionPatternConversion()
 
         self.hook_pattern.hook_conversion = pattern_conversion
 
+        # Store intermediate values for pattern creation
+        self._attn_scores = None
+        self._pattern = None
+
     def _process_output(self, output: Any) -> Any:
         """Process the output from the original component.
 
+        This method intercepts the output to create attention patterns
+        the same way as the old implementation.
+
         Args:
             output: Raw output from the original component
 
         Returns:
             Processed output with hooks applied
         """
+        # Extract attention scores from the output
+        attn_scores = self._extract_attention_scores(output)
+
+        if attn_scores is not None:
+            # Create attention pattern the same way as old implementation
+            attn_scores = self.hook_attn_scores(attn_scores)
+            pattern = F.softmax(attn_scores, dim=-1)
+            if not isinstance(pattern, torch.Tensor):
+                raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
+            pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
+            pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
+
+            # Store the pattern for potential use in result calculation
+            self._pattern = pattern
+
+            # Apply the pattern to the output if needed
+            output = self._apply_pattern_to_output(output, pattern)
+
+        return output
+
+    def _extract_attention_scores(self, output: Any) -> Optional[torch.Tensor]:
+        """Extract attention scores from the output.
+
+        Args:
+            output: Output from the original component
+
+        Returns:
+            Attention scores tensor or None if not found
+        """
         if isinstance(output, tuple):
-            return self._process_tuple_output(output)
+            # Look for attention scores in tuple output
+            for element in output:
+                if isinstance(element, torch.Tensor) and element.dim() == 4:
+                    # Assume 4D tensor is attention scores [batch, heads, query_pos, key_pos]
+                    return element
         elif isinstance(output, dict):
-            return self._process_dict_output(output)
+            # Look for attention scores in dict output
+            for key in ["attentions", "attention_weights", "attention_scores"]:
+                if key in output and isinstance(output[key], torch.Tensor):
+                    return output[key]
+
+        return None
+
+    def _apply_pattern_to_output(self, output: Any, pattern: torch.Tensor) -> Any:
+        """Apply the attention pattern to the output.
+
+        This method simulates how the old implementation uses the pattern
+        to calculate the final output.
+
+        Args:
+            output: Original output from the component
+            pattern: Attention pattern tensor
+
+        Returns:
+            Modified output with pattern applied
+        """
+        if isinstance(output, tuple):
+            return self._apply_pattern_to_tuple_output(output, pattern)
+        elif isinstance(output, dict):
+            return self._apply_pattern_to_dict_output(output, pattern)
         else:
-            return self._process_single_output(output)
+            return self._apply_pattern_to_single_output(output, pattern)
+
+    def _apply_pattern_to_tuple_output(
+        self, output: Tuple[Any, ...], pattern: torch.Tensor
+    ) -> Tuple[Any, ...]:
+        """Apply pattern to tuple output.
+
+        Args:
+            output: Tuple output from attention
+            pattern: Attention pattern tensor
+
+        Returns:
+            Processed tuple with pattern applied
+        """
+        processed_output = []
+
+        for i, element in enumerate(output):
+            if i == 0:  # First element is typically hidden states
+                if element is not None:
+                    element = self._apply_hook_preserving_structure(
+                        element, self.hook_hidden_states
+                    )
+                    # Apply the pattern to the hidden states
+                    element = self._apply_pattern_to_hidden_states(element, pattern)
+            elif i == 1 or i == 2:  # Attention weights indices
+                if isinstance(element, torch.Tensor):
+                    # Replace with our computed pattern
+                    element = pattern
+            processed_output.append(element)
+
+        # Apply the main hook_out to the first element (hidden states) if it exists
+        if len(processed_output) > 0 and processed_output[0] is not None:
+            processed_output[0] = self._apply_hook_preserving_structure(
+                processed_output[0], self.hook_out
+            )
+
+        return tuple(processed_output)
+
+    def _apply_pattern_to_dict_output(
+        self, output: Dict[str, Any], pattern: torch.Tensor
+    ) -> Dict[str, Any]:
+        """Apply pattern to dictionary output.
+
+        Args:
+            output: Dictionary output from attention
+            pattern: Attention pattern tensor
+
+        Returns:
+            Processed dictionary with pattern applied
+        """
+        processed_output = {}
+
+        for key, value in output.items():
+            if key in ["last_hidden_state", "hidden_states"] and value is not None:
+                value = self._apply_hook_preserving_structure(value, self.hook_hidden_states)
+                # Apply the pattern to the hidden states
+                value = self._apply_pattern_to_hidden_states(value, pattern)
+            elif key in ["attentions", "attention_weights"] and value is not None:
+                # Replace with our computed pattern
+                value = pattern
+            processed_output[key] = value
+
+        # Apply hook_hidden_states and hook_out to the main output (usually hidden_states)
+        main_key = next((k for k in output.keys() if "hidden" in k.lower()), None)
+        if main_key and main_key in processed_output:
+            processed_output[main_key] = self._apply_hook_preserving_structure(
+                processed_output[main_key], self.hook_out
+            )
+
+        return processed_output
+
+    def _apply_pattern_to_single_output(
+        self, output: torch.Tensor, pattern: torch.Tensor
+    ) -> torch.Tensor:
+        """Apply pattern to single tensor output.
+
+        Args:
+            output: Single tensor output from attention
+            pattern: Attention pattern tensor
+
+        Returns:
+            Processed tensor with pattern applied
+        """
+        # Apply hooks for single tensor output
+        output = self._apply_hook_preserving_structure(output, self.hook_hidden_states)
+        # Apply the pattern to the output
+        output = self._apply_pattern_to_hidden_states(output, pattern)
+        output = self._apply_hook_preserving_structure(output, self.hook_out)
+        return output
+
+    def _apply_pattern_to_hidden_states(
+        self, hidden_states: torch.Tensor, pattern: torch.Tensor
+    ) -> torch.Tensor:
+        """Apply attention pattern to hidden states.
+
+        This simulates the old implementation's calculate_z_scores method.
+
+        Args:
+            hidden_states: Hidden states tensor
+            pattern: Attention pattern tensor
+
+        Returns:
+            Modified hidden states with pattern applied
+        """
+        # This is a simplified version - in the real implementation,
+        # we would need to extract V from the original component and apply
+        # the pattern properly. For now, we just apply the pattern as a hook.
+        return self.hook_hidden_states(hidden_states)
 
     def _process_tuple_output(self, output: Tuple[Any, ...]) -> Tuple[Any, ...]:
         """Process tuple output from attention layer.
@@ -202,8 +392,8 @@ def _apply_hook_preserving_structure(self, element: Any, hook_fn) -> Any:
             if isinstance(element[0], torch.Tensor):
                 processed_elements[0] = hook_fn(element[0])
             return tuple(processed_elements)
-        # For other types, return as-is
-        return element
+        else:
+            return element
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Forward pass through the attention layer.
@@ -237,8 +427,6 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         # Process output
         output = self._process_output(output)
 
-        # Update hook outputs for debugging/inspection
-
         return output
 
     def get_attention_weights(self) -> Optional[torch.Tensor]:
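
Note: the heart of the new `_process_output` path is the scores-to-pattern step; a self-contained sketch of the same softmax-plus-NaN-masking logic, with purely illustrative shapes:

```python
import torch
import torch.nn.functional as F

# Illustrative scores tensor: [batch, head_index, query_pos, key_pos].
attn_scores = torch.randn(1, 12, 5, 5)
# A fully masked row (all -inf) makes softmax return NaN for every entry in that row.
attn_scores[0, 0, 0, :] = float("-inf")

pattern = F.softmax(attn_scores, dim=-1)
# As in _process_output, NaNs from fully masked rows are zeroed out rather than propagated.
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)

assert pattern.shape == (1, 12, 5, 5)
assert not torch.isnan(pattern).any()
```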
