
Commit 04b751f

Fix attention vizualizer (#40285)
* make visualizer rely on create causal mask
* format
* fixup
* fixup
* read token
* read token, duh
* what is up with that token
* small tests?
* adjust
* try with flush
* normalize for ANSI
* buffer shenanigans
1 parent 1e1db12 commit 04b751f
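
The tests added in this commit drive the visualizer the same way a user would: build an AttentionMaskVisualizer from a checkpoint name and call it on a prompt, which prints the attention-mask matrix to stdout. A minimal sketch; the checkpoint name below is only an example and may be gated behind a Hugging Face access token:

```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer

# Build the visualizer from a checkpoint name, then call it on a prompt.
# The checkpoint here is only an example; any supported model repo works.
visualizer = AttentionMaskVisualizer("meta-llama/Llama-2-7b-hf")
visualizer("Plants create energy through a process known as")
```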

File tree

2 files changed: +144, -6 lines changed


src/transformers/utils/attention_visualizer.py

Lines changed: 17 additions & 6 deletions
@@ -16,6 +16,7 @@
 import requests
 from PIL import Image
 
+from ..masking_utils import create_causal_mask
 from ..models.auto.auto_factory import _get_model_class
 from ..models.auto.configuration_auto import AutoConfig
 from ..models.auto.modeling_auto import MODEL_FOR_PRETRAINING_MAPPING, MODEL_MAPPING
@@ -207,13 +208,23 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
 
         model.config._attn_implementation = "eager"
         model.train()
-        attention_mask = ~model._update_causal_mask(
+
+        batch_size, seq_length = attention_mask.shape
+        input_embeds = torch.zeros((batch_size, seq_length, model.config.hidden_size), dtype=self.model.dtype)
+        cache_position = torch.arange(seq_length)
+
+        causal_mask = create_causal_mask(
+            config=model.config,
+            input_embeds=input_embeds,
             attention_mask=attention_mask,
-            input_tensor=attention_mask.to(self.model.dtype),
-            cache_position=torch.arange(attention_mask.shape[1]),
+            cache_position=cache_position,
             past_key_values=None,
-            **kwargs,
-        ).bool()
+        )
+
+        if causal_mask is not None:
+            attention_mask = ~causal_mask.bool()
+        else:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, seq_length, seq_length)
         top_bottom_border = "##" * (
             len(f"Attention visualization for {self.config.model_type} | {self.mapped_cls}") + 4
         )  # Box width adjusted to text length
@@ -225,7 +236,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
                 len(top_bottom_border)
             )
             + " "
-            + side_border
+            + side_border,
         )
         print(f"{top_bottom_border}")
         f_string = generate_attention_matrix_from_mask(
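
For reference, a minimal sketch of the new mask-building path outside the visualizer, assuming create_causal_mask behaves the way the hunk above uses it: called with the keyword arguments shown, returning either a 4D mask tensor or None (hence the fallback branch). The tiny config built with AutoConfig.for_model is only a stand-in for the loaded model's real config.

```python
import torch

from transformers import AutoConfig
from transformers.masking_utils import create_causal_mask

# Stand-in config; the visualizer uses the loaded model's real config instead.
config = AutoConfig.for_model("llama", hidden_size=16, num_attention_heads=2, num_hidden_layers=1)
config._attn_implementation = "eager"

batch_size, seq_length = 1, 6
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)   # (batch, seq) padding mask
input_embeds = torch.zeros(batch_size, seq_length, config.hidden_size)  # dummy embeddings, only shape/dtype matter
cache_position = torch.arange(seq_length)

causal_mask = create_causal_mask(
    config=config,
    input_embeds=input_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=None,
)

if causal_mask is not None:
    # eager-style masks put non-zero values at blocked positions,
    # so inverting the boolean view flags the visible ones
    visible = ~causal_mask.bool()
else:
    # fall back to broadcasting the 2D padding mask, as the visualizer does
    visible = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, seq_length, seq_length)
```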
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import builtins
+import io
+import re
+import unittest
+
+from transformers.testing_utils import require_read_token, require_torch
+from transformers.utils.attention_visualizer import AttentionMaskVisualizer
+
+
+ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")
+
+
+def _normalize(s: str) -> str:
+    # drop ANSI (colors may be disabled on CI), normalize line endings,
+    # and strip trailing spaces without touching alignment inside lines
+    s = ANSI_RE.sub("", s)
+    s = s.replace("\r\n", "\n").replace("\r", "\n")
+    return "\n".join(line.rstrip() for line in s.split("\n")).strip()
+
+
+@require_torch
+class AttentionMaskVisualizerTester(unittest.TestCase):
+    """Test suite for AttentionMaskVisualizer"""
+
+    @require_read_token
+    def test_paligemma_multimodal_visualization(self):
+        """Test AttentionMaskVisualizer with PaliGemma multimodal model"""
+        model_name = "hf-internal-testing/namespace_google_repo_name_paligemma-3b-pt-224"
+        input_text = "<img> What is in this image?"
+
+        buf = io.StringIO()
+        orig_print = builtins.print
+
+        def _print(*args, **kwargs):
+            kwargs.setdefault("file", buf)
+            orig_print(*args, **kwargs)
+
+        try:
+            builtins.print = _print
+            visualizer = AttentionMaskVisualizer(model_name)
+            visualizer(input_text)
+        finally:
+            builtins.print = orig_print
+        output = buf.getvalue()
+
+        expected_output = """
+##########################################################################################################################################################################################################################################
+## Attention visualization for \033[1mpaligemma:hf-internal-testing/namespace_google_repo_name_paligemma-3b-pt-224\033[0m PaliGemmaModel ##
+##########################################################################################################################################################################################################################################
+\033[92m■\033[0m: i == j (diagonal) \033[93m■\033[0m: token_type_ids
+Attention Matrix
+
+
+\033[93m'<image>'\033[0m: 0 \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+\033[93m'<image>'\033[0m: 1 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+\033[93m'<image>'\033[0m: 2 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+\033[93m'<image>'\033[0m: 3 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+\033[93m'<image>'\033[0m: 4 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+'<bos>' : 5 ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+'▁What' : 6 ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+'▁is' : 7 ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+'▁in' : 8 ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ |
+'▁this' : 9 ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ |
+'▁image' : 10 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ |
+'?' : 11 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ |
+'\\n' : 12 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ |
+'<eos>' : 13 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m |
+##########################################################################################################################################################################################################################################
+"""  # noqa
+
+        self.assertEqual(_normalize(output), _normalize(expected_output))
+
+    @require_read_token
+    def test_llama_text_only_visualization(self):
+        """Test AttentionMaskVisualizer with Llama text-only model"""
+        model_name = "hf-internal-testing/namespace_meta-llama_repo_name_Llama-2-7b-hf"
+        input_text = "Plants create energy through a process known as"
+
+        buf = io.StringIO()
+        orig_print = builtins.print
+
+        def _print(*args, **kwargs):
+            kwargs.setdefault("file", buf)
+            orig_print(*args, **kwargs)
+
+        try:
+            builtins.print = _print
+            visualizer = AttentionMaskVisualizer(model_name)
+            visualizer(input_text)
+        finally:
+            builtins.print = orig_print
+        output = buf.getvalue()
+
+        expected_output = """
+##########################################################################################################################################################################################################
+## Attention visualization for \033[1mllama:hf-internal-testing/namespace_meta-llama_repo_name_Llama-2-7b-hf\033[0m LlamaModel ##
+##########################################################################################################################################################################################################
+\033[92m■\033[0m: i == j (diagonal) \033[93m■\033[0m: token_type_ids
+Attention Matrix
+
+'▁Pl' : 0 \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+'ants' : 1 ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+'▁create' : 2 ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
+'▁energy' : 3 ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ |
+'▁through': 4 ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ |
+'▁a' : 5 ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ |
+'▁process': 6 ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ |
+'▁known' : 7 ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ |
+'▁as' : 8 ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m |
+##########################################################################################################################################################################################################
+"""  # noqa
+
+        self.assertEqual(_normalize(output), _normalize(expected_output))
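
The print-capture scaffolding in these tests is presumably the "buffer shenanigans" the commit message mentions: instead of redirecting stdout, the tests temporarily swap builtins.print for a wrapper whose file argument defaults to a StringIO, which also catches calls made with flush=True. A stripped-down sketch of the same pattern:

```python
import builtins
import io


def noisy():
    # stands in for the visualizer, which prints its matrix directly
    print("attention matrix", flush=True)


buf = io.StringIO()
orig_print = builtins.print


def _print(*args, **kwargs):
    kwargs.setdefault("file", buf)  # route output to the buffer unless a file was passed explicitly
    orig_print(*args, **kwargs)


try:
    builtins.print = _print
    noisy()
finally:
    builtins.print = orig_print  # always restore the real print

assert buf.getvalue() == "attention matrix\n"
```

Using kwargs.setdefault keeps any explicit file= argument a caller passes intact, so only bare print calls are rerouted to the buffer.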
