
Commit e8fda51

registered hook correctly (#1051)
* registered hook correctly
* fixed typing

1 parent ee9b44b · commit e8fda51

File tree

1 file changed: +40 additions, -29 deletions

  • transformer_lens/model_bridge/generalized_components/attention.py


transformer_lens/model_bridge/generalized_components/attention.py

Lines changed: 40 additions & 29 deletions
@@ -120,7 +120,7 @@ def _process_output(self, output: Any) -> Any:
         """Process the output from the original component.

         This method intercepts the output to create attention patterns
-        the same way as the old implementation.
+        the same way as the old implementation and applies hook_out.

         Args:
             output: Raw output from the original component
@@ -145,6 +145,17 @@ def _process_output(self, output: Any) -> Any:

             # Apply the pattern to the output if needed
             output = self._apply_pattern_to_output(output, pattern)
+        else:
+            # If no attention scores found, still apply hooks to the output
+            if isinstance(output, tuple):
+                output = self._process_tuple_output(output)
+            elif isinstance(output, dict):
+                output = self._process_dict_output(output)
+            else:
+                output = self._process_single_output(output)
+
+        # Always apply hook_out to the main output
+        output = self._apply_hook_out_to_output(output)

         return output

@@ -219,12 +230,6 @@ def _apply_pattern_to_tuple_output(
                 element = pattern
             processed_output.append(element)

-        # Apply the main hook_out to the first element (hidden states) if it exists
-        if len(processed_output) > 0 and processed_output[0] is not None:
-            processed_output[0] = self._apply_hook_preserving_structure(
-                processed_output[0], self.hook_out
-            )
-
         return tuple(processed_output)

     def _apply_pattern_to_dict_output(
@@ -251,13 +256,6 @@ def _apply_pattern_to_dict_output(
                 value = pattern
             processed_output[key] = value

-        # Apply hook_hidden_states and hook_out to the main output (usually hidden_states)
-        main_key = next((k for k in output.keys() if "hidden" in k.lower()), None)
-        if main_key and main_key in processed_output:
-            processed_output[main_key] = self._apply_hook_preserving_structure(
-                processed_output[main_key], self.hook_out
-            )
-
         return processed_output

     def _apply_pattern_to_single_output(
@@ -276,7 +274,6 @@ def _apply_pattern_to_single_output(
         output = self._apply_hook_preserving_structure(output, self.hook_hidden_states)
         # Apply the pattern to the output
         output = self._apply_pattern_to_hidden_states(output, pattern)
-        output = self._apply_hook_preserving_structure(output, self.hook_out)
         return output

     def _apply_pattern_to_hidden_states(
@@ -325,12 +322,6 @@ def _process_tuple_output(self, output: Tuple[Any, ...]) -> Tuple[Any, ...]:
                 element = self._apply_hook_preserving_structure(element, self.hook_pattern)
             processed_output.append(element)

-        # Apply the main hook_out to the first element (hidden states) if it exists
-        if len(processed_output) > 0 and processed_output[0] is not None:
-            processed_output[0] = self._apply_hook_preserving_structure(
-                processed_output[0], self.hook_out
-            )
-
         return tuple(processed_output)

     def _process_dict_output(self, output: Dict[str, Any]) -> Dict[str, Any]:
@@ -351,13 +342,6 @@ def _process_dict_output(self, output: Dict[str, Any]) -> Dict[str, Any]:
                 value = self._apply_hook_preserving_structure(value, self.hook_pattern)
             processed_output[key] = value

-        # Apply hook_hidden_states and hook_out to the main output (usually hidden_states)
-        main_key = next((k for k in output.keys() if "hidden" in k.lower()), None)
-        if main_key and main_key in processed_output:
-            processed_output[main_key] = self._apply_hook_preserving_structure(
-                processed_output[main_key], self.hook_out
-            )
-
         return processed_output

     def _process_single_output(self, output: torch.Tensor) -> torch.Tensor:
@@ -371,7 +355,6 @@ def _process_single_output(self, output: torch.Tensor) -> torch.Tensor:
         """
         # Apply hooks for single tensor output
         output = self._apply_hook_preserving_structure(output, self.hook_hidden_states)
-        output = self._apply_hook_preserving_structure(output, self.hook_out)
         return output

     def _apply_hook_preserving_structure(self, element: Any, hook_fn) -> Any:
@@ -395,6 +378,34 @@ def _apply_hook_preserving_structure(self, element: Any, hook_fn) -> Any:
         else:
             return element

+    def _apply_hook_out_to_output(self, output: Any) -> Any:
+        """Apply hook_out to the main output tensor.
+
+        Args:
+            output: The output to process (can be tensor, tuple, or dict)
+
+        Returns:
+            The output with hook_out applied to the main tensor
+        """
+        if isinstance(output, torch.Tensor):
+            return self.hook_out(output)
+        elif isinstance(output, tuple) and len(output) > 0:
+            # Apply hook_out to the first element (typically hidden states)
+            processed_tuple = list(output)
+            if isinstance(output[0], torch.Tensor):
+                processed_tuple[0] = self.hook_out(output[0])
+            return tuple(processed_tuple)
+        elif isinstance(output, dict):
+            # Apply hook_out to the main hidden states in dictionary
+            processed_dict = output.copy()
+            for key in ["last_hidden_state", "hidden_states"]:
+                if key in processed_dict and isinstance(processed_dict[key], torch.Tensor):
+                    processed_dict[key] = self.hook_out(processed_dict[key])
+                    break  # Only apply to the first found key
+            return processed_dict
+        else:
+            return output
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Forward pass through the attention layer.
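In short: before this commit, hook_out was applied piecemeal inside the pattern-application and output-processing helpers, and not at all on the path where no attention scores were found; after it, _process_output applies hook_out exactly once at the end, dispatching on the container type of the output. The following is a minimal, self-contained sketch of that dispatch, not code from the commit: hook_out is stubbed here as a call-counting identity function rather than a real TransformerLens HookPoint, and the apply_hook_out name and toy shapes are illustrative only.

import torch

calls = []

def hook_out(tensor: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real hook point: record the call, return the tensor unchanged.
    calls.append(tuple(tensor.shape))
    return tensor

def apply_hook_out(output):
    # Mirrors _apply_hook_out_to_output: hook the main tensor, whatever the container.
    if isinstance(output, torch.Tensor):
        return hook_out(output)
    elif isinstance(output, tuple) and len(output) > 0:
        # Tuple outputs: the first element is typically the hidden states.
        processed = list(output)
        if isinstance(output[0], torch.Tensor):
            processed[0] = hook_out(output[0])
        return tuple(processed)
    elif isinstance(output, dict):
        # Dict outputs: hook only the first hidden-states-like key found.
        processed = output.copy()
        for key in ["last_hidden_state", "hidden_states"]:
            if key in processed and isinstance(processed[key], torch.Tensor):
                processed[key] = hook_out(processed[key])
                break
        return processed
    return output

hidden = torch.randn(2, 4, 8)
apply_hook_out(hidden)                          # bare tensor
apply_hook_out((hidden, None))                  # tuple: (hidden_states, attn_weights)
apply_hook_out({"last_hidden_state": hidden})   # dict-style model output
assert len(calls) == 3  # hook_out fired exactly once per output, regardless of container

Whatever shape the wrapped module returns, the hook now fires once per forward call, which is what makes interventions and caching on hook_out behave consistently across output types.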