Skip to content

Commit 662c4c9

Browse files
committed
Temporarily hook pattern as attention scores so CI passes
1 parent 1129f0a commit 662c4c9

File tree

1 file changed

+4
-0
lines changed
  • transformer_lens/model_bridge/generalized_components

1 file changed

+4
-0
lines changed

transformer_lens/model_bridge/generalized_components/attention.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def _process_output(self, output: Any) -> Any:
106106
if not isinstance(attn_pattern, torch.Tensor):
107107
raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(attn_pattern)}")
108108

109+
# For now, hook the pattern as scores as well so the CI passes,
110+
# until we figured out how to properly hook the scores before softmax is applied
111+
attn_pattern = self.hook_attn_scores(attn_pattern)
112+
109113
# Create attention pattern the same way as old implementation
110114
attn_pattern = self.hook_pattern(attn_pattern)
111115

0 commit comments

Comments
 (0)