Commit 3a6f596

Improve hook compatibility
1 parent 7e5a5d9 commit 3a6f596

File tree

2 files changed: +20 -51 lines changed

transformer_lens/model_bridge/generalized_components/attention.py

Lines changed: 19 additions & 50 deletions
```diff
@@ -6,7 +6,6 @@
 from typing import Any, Dict, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 
 from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import (
     AttentionAutoConversion,
```
```diff
@@ -20,29 +19,6 @@
 )
 
 
-class AttentionPatternConversion(BaseHookConversion):
-    """Custom conversion rule for attention patterns that always removes batch dimension."""
-
-    def handle_conversion(self, tensor: torch.Tensor, *args) -> torch.Tensor:
-        """Convert attention pattern tensor to standard shape [n_heads, pos, pos].
-
-        Args:
-            tensor: Input tensor with shape [batch, n_heads, pos, pos] or [n_heads, pos, pos]
-            *args: Additional context arguments (ignored)
-
-        Returns:
-            Tensor with shape [n_heads, pos, pos]
-        """
-        if tensor.dim() == 4:
-            # Remove batch dimension if present
-            return tensor.squeeze(0)
-        elif tensor.dim() == 3:
-            # Already in correct shape
-            return tensor
-        else:
-            raise ValueError(f"Unexpected tensor shape for attention pattern: {tensor.shape}")
-
-
 class AttentionBridge(GeneralizedComponent):
     """Bridge component for attention layers.
```
```diff
@@ -55,7 +31,7 @@ class AttentionBridge(GeneralizedComponent):
         "hook_q": "q.hook_out",
         "hook_k": "k.hook_out",
         "hook_v": "v.hook_out",
-        "hook_z": "hook_hidden_states",
+        "hook_z": "o.hook_in",
     }
 
     property_aliases = {
```
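The `hook_z` alias now resolves to the input of the output projection `o` instead of `hook_hidden_states`. Below is a minimal, self-contained sketch (not the bridge implementation; the shapes, names, and einsum convention are illustrative assumptions) of why the tensor entering the output projection is the per-head weighted sum of values, i.e. the `z` that TransformerLens conventionally exposes as `hook_z`:

```python
import torch

# Illustrative dimensions only.
batch, n_heads, pos, d_head, d_model = 2, 4, 8, 16, 64

pattern = torch.softmax(torch.randn(batch, n_heads, pos, pos), dim=-1)  # [batch, head, q_pos, k_pos]
v = torch.randn(batch, n_heads, pos, d_head)                            # [batch, head, k_pos, d_head]
W_O = torch.randn(n_heads * d_head, d_model)                            # output projection weight

# z: per-head weighted sum of values. This is what a hook on the input of the
# output projection ("o.hook_in") would observe.
z = torch.einsum("bhqk,bhkd->bhqd", pattern, v)

# The attention output is just z flattened over heads and projected through W_O.
attn_out = z.permute(0, 2, 1, 3).reshape(batch, pos, n_heads * d_head) @ W_O
print(attn_out.shape)  # torch.Size([2, 8, 64])
```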
```diff
@@ -103,14 +79,9 @@ def __init__(
         # Apply conversion rule to attention-specific hooks
         self.hook_hidden_states.hook_conversion = conversion_rule
 
-        # Set up pattern conversion rule - use provided rule or create default
+        # Set up pattern conversion rule if provided
         if pattern_conversion_rule is not None:
-            pattern_conversion = pattern_conversion_rule
-        else:
-            # Use custom conversion rule that always removes batch dimension
-            pattern_conversion = AttentionPatternConversion()
-
-        self.hook_pattern.hook_conversion = pattern_conversion
+            self.hook_pattern.hook_conversion = pattern_conversion_rule
 
         # Store intermediate values for pattern creation
         self._attn_scores = None
```
```diff
@@ -129,24 +100,22 @@ def _process_output(self, output: Any) -> Any:
             Processed output with hooks applied
         """
         # Extract attention scores from the output
-        attn_scores = self._extract_attention_scores(output)
+        attn_pattern = self._extract_attention_pattern(output)
+
+        if attn_pattern is not None:
+            if not isinstance(attn_pattern, torch.Tensor):
+                raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(attn_pattern)}")
 
-        if attn_scores is not None:
             # Create attention pattern the same way as old implementation
-            attn_scores = self.hook_attn_scores(attn_scores)
-            pattern = F.softmax(attn_scores, dim=-1)
-            if not isinstance(pattern, torch.Tensor):
-                raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
-            pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
-            pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
+            attn_pattern = self.hook_pattern(attn_pattern)
 
             # Store the pattern for potential use in result calculation
-            self._pattern = pattern
+            self._pattern = attn_pattern
 
             # Apply the pattern to the output if needed
-            output = self._apply_pattern_to_output(output, pattern)
+            output = self._apply_pattern_to_output(output, attn_pattern)
         else:
-            # If no attention scores found, still apply hooks to the output
+            # If no attention pattern found, still apply hooks to the output
             if isinstance(output, tuple):
                 output = self._process_tuple_output(output)
             elif isinstance(output, dict):
```
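With this change the bridge treats the extracted tensor as an already-normalized attention pattern, so the explicit `F.softmax` and NaN handling are dropped (Hugging Face attention modules typically return post-softmax probabilities when attention outputs are requested). A throwaway sketch, with made-up numbers, of why re-applying softmax to an existing pattern would be wrong:

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[4.0, 1.0, 0.0]])  # raw attention scores for one query position
pattern = F.softmax(scores, dim=-1)       # post-softmax pattern, sharply peaked: ~[0.94, 0.05, 0.02]

# Normalizing again flattens the distribution toward uniform (~[0.55, 0.23, 0.22]),
# so a tensor that is already a pattern must not be softmaxed a second time.
double = F.softmax(pattern, dim=-1)
print(pattern, double)
```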
```diff
@@ -159,24 +128,24 @@ def _process_output(self, output: Any) -> Any:
 
         return output
 
-    def _extract_attention_scores(self, output: Any) -> Optional[torch.Tensor]:
-        """Extract attention scores from the output.
+    def _extract_attention_pattern(self, output: Any) -> Optional[torch.Tensor]:
+        """Extract attention pattern from the output.
 
         Args:
             output: Output from the original component
 
         Returns:
-            Attention scores tensor or None if not found
+            Attention pattern tensor or None if not found
         """
         if isinstance(output, tuple):
-            # Look for attention scores in tuple output
+            # Look for attention pattern in tuple output
             for element in output:
                 if isinstance(element, torch.Tensor) and element.dim() == 4:
-                    # Assume 4D tensor is attention scores [batch, heads, query_pos, key_pos]
+                    # Assume 4D tensor is attention pattern [batch, heads, query_pos, key_pos]
                     return element
         elif isinstance(output, dict):
-            # Look for attention scores in dict output
-            for key in ["attentions", "attention_weights", "attention_scores"]:
+            # Look for attention pattern in dict output
+            for key in ["attentions", "attention_weights", "attention_scores", "attn_weights"]:
                 if key in output and isinstance(output[key], torch.Tensor):
                     return output[key]
```
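The dict branch now also recognizes an `attn_weights` key. A hypothetical standalone illustration of that lookup (the helper name and mock output are invented for this example; only the key list mirrors the diff):

```python
from typing import Optional

import torch


def extract_pattern_from_dict(output: dict) -> Optional[torch.Tensor]:
    """Mirror of the dict branch of _extract_attention_pattern (illustrative only)."""
    for key in ["attentions", "attention_weights", "attention_scores", "attn_weights"]:
        if key in output and isinstance(output[key], torch.Tensor):
            return output[key]
    return None


# Mock output in the style of a wrapped attention module that exposes "attn_weights".
mock_output = {
    "hidden_states": torch.randn(1, 8, 64),
    "attn_weights": torch.rand(1, 4, 8, 8),  # [batch, heads, query_pos, key_pos]
}
print(extract_pattern_from_dict(mock_output).shape)  # torch.Size([1, 4, 8, 8])
```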

transformer_lens/model_bridge/generalized_components/block.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@ class BlockBridge(GeneralizedComponent):
 
     hook_aliases = {
         "hook_resid_pre": "hook_in",
-        "hook_resid_mid": "attn.hook_out",
+        "hook_resid_mid": "ln2.hook_in",
         "hook_resid_post": "hook_out",
         "hook_attn_in": "attn.hook_in",
         "hook_attn_out": "attn.hook_out",
```
