from typing import Any, Dict, Optional, Tuple

import torch
+import torch.nn.functional as F

from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import (
    AttentionAutoConversion,
)
from transformer_lens.conversion_utils.conversion_steps.base_hook_conversion import (
    BaseHookConversion,
)
-from transformer_lens.conversion_utils.conversion_steps.rearrange_hook_conversion import (
-    RearrangeHookConversion,
-)
from transformer_lens.hook_points import HookPoint
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)


+class AttentionPatternConversion(BaseHookConversion):
+    """Custom conversion rule for attention patterns that always removes the batch dimension."""
+
+    def handle_conversion(self, tensor: torch.Tensor, *args) -> torch.Tensor:
+        """Convert an attention pattern tensor to the standard shape [n_heads, pos, pos].
+
+        Args:
+            tensor: Input tensor with shape [batch, n_heads, pos, pos] or [n_heads, pos, pos]
+            *args: Additional context arguments (ignored)
+
+        Returns:
+            Tensor with shape [n_heads, pos, pos]
+        """
+        if tensor.dim() == 4:
+            # Remove batch dimension if present
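+            # Note: squeeze(0) only drops the batch dimension when it has size 1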
+            return tensor.squeeze(0)
+        elif tensor.dim() == 3:
+            # Already in correct shape
+            return tensor
+        else:
+            raise ValueError(f"Unexpected tensor shape for attention pattern: {tensor.shape}")
+
+
class AttentionBridge(GeneralizedComponent):
    """Bridge component for attention layers.

-    This component wraps attention layers from different architectures and provides
-    a standardized interface for hook registration and execution.
+    This component handles the conversion between Hugging Face attention layers
+    and TransformerLens attention components.
    """

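+    # Legacy TransformerLens hook names mapped onto this bridge's hook points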
    hook_aliases = {
-        "hook_result": "hook_hidden_states",
-        "hook_attn_scores": "o.hook_in",
+        "hook_result": "hook_out",
        "hook_q": "q.hook_out",
        "hook_k": "k.hook_out",
        "hook_v": "v.hook_out",
-        "hook_z": "o.hook_out",
+        "hook_z": "hook_hidden_states",
    }

    property_aliases = {
@@ -65,7 +85,7 @@ def __init__(
            submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
            conversion_rule: Optional conversion rule. If None, AttentionAutoConversion will be used
            pattern_conversion_rule: Optional conversion rule for attention patterns. If None,
-                uses default RearrangeHookConversion to reshape to (batch, n_heads, pos, pos)
+                uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape
        """
        # Set up conversion rule - use AttentionAutoConversion if None
        if conversion_rule is None:
@@ -74,8 +94,11 @@ def __init__(
        super().__init__(
            name, config=config, submodules=submodules or {}, conversion_rule=conversion_rule
        )
-        self.hook_hidden_states = HookPoint()
+
+        # Create only the hook points that are actually used for attention processing
+        self.hook_attn_scores = HookPoint()
        self.hook_pattern = HookPoint()
+        self.hook_hidden_states = HookPoint()
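+        # hook_attn_scores fires on the raw scores, hook_pattern on the softmaxed
+        # pattern, and hook_hidden_states on the attention output before hook_out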

        # Apply conversion rule to attention-specific hooks
        self.hook_hidden_states.hook_conversion = conversion_rule
@@ -84,29 +107,196 @@ def __init__(
        if pattern_conversion_rule is not None:
            pattern_conversion = pattern_conversion_rule
        else:
-            # Create default conversion rule for attention patterns - reshape to (batch, n_heads, pos, pos)
-            # This assumes the input is (batch, n_heads, seq_len, seq_len) or similar
-            pattern_conversion = RearrangeHookConversion(
-                "batch n_heads pos_q pos_k -> batch n_heads pos_q pos_k"
-            )
+            # Use custom conversion rule that always removes the batch dimension
+            pattern_conversion = AttentionPatternConversion()

        self.hook_pattern.hook_conversion = pattern_conversion

+        # Store intermediate values for pattern creation
+        self._attn_scores = None
+        self._pattern = None
+
    def _process_output(self, output: Any) -> Any:
        """Process the output from the original component.

+        This method intercepts the output to create attention patterns
+        the same way as the old implementation.
+
        Args:
            output: Raw output from the original component

        Returns:
            Processed output with hooks applied
        """
+        # Extract attention scores from the output
+        attn_scores = self._extract_attention_scores(output)
+
+        if attn_scores is not None:
+            # Create attention pattern the same way as the old implementation
+            attn_scores = self.hook_attn_scores(attn_scores)
+            pattern = F.softmax(attn_scores, dim=-1)
+            if not isinstance(pattern, torch.Tensor):
+                raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
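+            # NaNs appear when a row of scores is fully masked (all -inf), so zero them out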
+            pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
+            pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
+
+            # Store the pattern for potential use in result calculation
+            self._pattern = pattern
+
+            # Apply the pattern to the output if needed
+            output = self._apply_pattern_to_output(output, pattern)
+
+        return output
+
+    def _extract_attention_scores(self, output: Any) -> Optional[torch.Tensor]:
+        """Extract attention scores from the output.
+
+        Args:
+            output: Output from the original component
+
+        Returns:
+            Attention scores tensor or None if not found
+        """
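+        # Heuristic: scan tuple outputs for a 4D tensor, or look up common
+        # Hugging Face keys in dict outputs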
        if isinstance(output, tuple):
-            return self._process_tuple_output(output)
+            # Look for attention scores in tuple output
+            for element in output:
+                if isinstance(element, torch.Tensor) and element.dim() == 4:
+                    # Assume 4D tensor is attention scores [batch, heads, query_pos, key_pos]
+                    return element
        elif isinstance(output, dict):
-            return self._process_dict_output(output)
+            # Look for attention scores in dict output
+            for key in ["attentions", "attention_weights", "attention_scores"]:
+                if key in output and isinstance(output[key], torch.Tensor):
+                    return output[key]
+
+        return None
+
+    def _apply_pattern_to_output(self, output: Any, pattern: torch.Tensor) -> Any:
+        """Apply the attention pattern to the output.
+
+        This method simulates how the old implementation uses the pattern
+        to calculate the final output.
+
+        Args:
+            output: Original output from the component
+            pattern: Attention pattern tensor
+
+        Returns:
+            Modified output with pattern applied
+        """
+        if isinstance(output, tuple):
+            return self._apply_pattern_to_tuple_output(output, pattern)
+        elif isinstance(output, dict):
+            return self._apply_pattern_to_dict_output(output, pattern)
        else:
-            return self._process_single_output(output)
+            return self._apply_pattern_to_single_output(output, pattern)
+
+    def _apply_pattern_to_tuple_output(
+        self, output: Tuple[Any, ...], pattern: torch.Tensor
+    ) -> Tuple[Any, ...]:
+        """Apply pattern to tuple output.
+
+        Args:
+            output: Tuple output from attention
+            pattern: Attention pattern tensor
+
+        Returns:
+            Processed tuple with pattern applied
+        """
+        processed_output = []
+
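+        # Assumes the usual Hugging Face ordering: hidden states first, with
+        # attention weights possibly in slot 1 or 2, which get replaced by our pattern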
+        for i, element in enumerate(output):
+            if i == 0:  # First element is typically hidden states
+                if element is not None:
+                    element = self._apply_hook_preserving_structure(
+                        element, self.hook_hidden_states
+                    )
+                    # Apply the pattern to the hidden states
+                    element = self._apply_pattern_to_hidden_states(element, pattern)
+            elif i == 1 or i == 2:  # Attention weights indices
+                if isinstance(element, torch.Tensor):
+                    # Replace with our computed pattern
+                    element = pattern
+            processed_output.append(element)
+
+        # Apply the main hook_out to the first element (hidden states) if it exists
+        if len(processed_output) > 0 and processed_output[0] is not None:
+            processed_output[0] = self._apply_hook_preserving_structure(
+                processed_output[0], self.hook_out
+            )
+
+        return tuple(processed_output)
+
+    def _apply_pattern_to_dict_output(
+        self, output: Dict[str, Any], pattern: torch.Tensor
+    ) -> Dict[str, Any]:
+        """Apply pattern to dictionary output.
+
+        Args:
+            output: Dictionary output from attention
+            pattern: Attention pattern tensor
+
+        Returns:
+            Processed dictionary with pattern applied
+        """
+        processed_output = {}
+
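+        # Only hidden-state and attention-weight entries are rewritten; all other
+        # keys pass through unchanged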
+        for key, value in output.items():
+            if key in ["last_hidden_state", "hidden_states"] and value is not None:
+                value = self._apply_hook_preserving_structure(value, self.hook_hidden_states)
+                # Apply the pattern to the hidden states
+                value = self._apply_pattern_to_hidden_states(value, pattern)
+            elif key in ["attentions", "attention_weights"] and value is not None:
+                # Replace with our computed pattern
+                value = pattern
+            processed_output[key] = value
+
+        # Apply hook_hidden_states and hook_out to the main output (usually hidden_states)
+        main_key = next((k for k in output.keys() if "hidden" in k.lower()), None)
+        if main_key and main_key in processed_output:
+            processed_output[main_key] = self._apply_hook_preserving_structure(
+                processed_output[main_key], self.hook_out
+            )
+
+        return processed_output
+
+    def _apply_pattern_to_single_output(
+        self, output: torch.Tensor, pattern: torch.Tensor
+    ) -> torch.Tensor:
+        """Apply pattern to single tensor output.
+
+        Args:
+            output: Single tensor output from attention
+            pattern: Attention pattern tensor
+
+        Returns:
+            Processed tensor with pattern applied
+        """
+        # Apply hooks for single tensor output
+        output = self._apply_hook_preserving_structure(output, self.hook_hidden_states)
+        # Apply the pattern to the output
+        output = self._apply_pattern_to_hidden_states(output, pattern)
+        output = self._apply_hook_preserving_structure(output, self.hook_out)
+        return output
+
+    def _apply_pattern_to_hidden_states(
+        self, hidden_states: torch.Tensor, pattern: torch.Tensor
+    ) -> torch.Tensor:
+        """Apply attention pattern to hidden states.
+
+        This simulates the old implementation's calculate_z_scores method.
+
+        Args:
+            hidden_states: Hidden states tensor
+            pattern: Attention pattern tensor
+
+        Returns:
+            Modified hidden states with pattern applied
+        """
+        # This is a simplified version - in the real implementation,
+        # we would need to extract V from the original component and apply
+        # the pattern properly. For now, we just apply the pattern as a hook.
+        return self.hook_hidden_states(hidden_states)

    def _process_tuple_output(self, output: Tuple[Any, ...]) -> Tuple[Any, ...]:
        """Process tuple output from attention layer.
@@ -202,8 +392,8 @@ def _apply_hook_preserving_structure(self, element: Any, hook_fn) -> Any:
            if isinstance(element[0], torch.Tensor):
                processed_elements[0] = hook_fn(element[0])
                return tuple(processed_elements)
-        # For other types, return as-is
-        return element
+            else:
+                return element

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the attention layer.
@@ -237,8 +427,6 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Process output
        output = self._process_output(output)

-        # Update hook outputs for debugging/inspection
-
        return output

    def get_attention_weights(self) -> Optional[torch.Tensor]: