Skip to content

Commit deebbc6

Browse files
committed
Fix negative index handling in MultiHeadAttention attention_axes
1 parent 2a5bb21 commit deebbc6

File tree

5 files changed

+48
-3
lines changed

5 files changed

+48
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ __pycache__
66
**/.vscode test/**
77
**/.vscode-smoke/**
88
**/.venv*/
9+
venv
910
bin/**
1011
build/**
1112
obj/**
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import numpy as np
2+
3+
import keras
4+
5+
6+
def test_attention_axes_negative_indexing_matches_positive():
7+
x = np.random.normal(size=(2, 3, 8, 4))
8+
9+
mha_pos = keras.layers.MultiHeadAttention(
10+
num_heads=2, key_dim=4, attention_axes=2
11+
)
12+
mha_neg = keras.layers.MultiHeadAttention(
13+
num_heads=2, key_dim=4, attention_axes=-2
14+
)
15+
16+
_ = mha_pos(x, x)
17+
_ = mha_neg(x, x)
18+
19+
mha_neg.set_weights(mha_pos.get_weights())
20+
21+
z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)
22+
z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)
23+
24+
assert z_pos.shape == z_neg.shape
25+
assert a_pos.shape == a_neg.shape
26+
27+
np.testing.assert_allclose(z_pos, z_neg, rtol=1e-5, atol=1e-5)
28+
np.testing.assert_allclose(a_pos, a_neg, rtol=1e-5, atol=1e-5)

integration_tests/test_save_img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ def test_save_jpg(tmp_path, shape, name):
2424
# Check that the image was saved correctly and converted to RGB if needed.
2525
loaded_img = load_img(path)
2626
loaded_array = img_to_array(loaded_img)
27-
assert loaded_array.shape == (50, 50, 3)
27+
assert loaded_array.shape == (50, 50, 3)

keras/src/layers/attention/multi_head_attention.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,17 @@ def _build_attention(self, rank):
378378
if self._attention_axes is None:
379379
self._attention_axes = tuple(range(1, rank - 2))
380380
else:
381-
self._attention_axes = tuple(self._attention_axes)
381+
# Normalize negative indices relative to INPUT rank (rank - 1)
382+
input_rank = rank - 1
383+
normalized_axes = []
384+
for ax in self._attention_axes:
385+
if ax < 0:
386+
# Normalize relative to input rank
387+
normalized_ax = input_rank + ax
388+
else:
389+
normalized_ax = ax
390+
normalized_axes.append(normalized_ax)
391+
self._attention_axes = tuple(normalized_axes)
382392
(
383393
self._dot_product_equation,
384394
self._combine_equation,
@@ -760,6 +770,12 @@ def _build_attention_equation(rank, attn_axes):
760770
Returns:
761771
Einsum equations.
762772
"""
773+
# Normalize negative indices to positive indices
774+
if isinstance(attn_axes, (list, tuple)):
775+
attn_axes = tuple(ax % rank if ax < 0 else ax for ax in attn_axes)
776+
else:
777+
attn_axes = (attn_axes % rank if attn_axes < 0 else attn_axes,)
778+
763779
target_notation = ""
764780
for i in range(rank):
765781
target_notation += _index_to_einsum_variable(i)

keras/src/utils/image_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
179179
if file_format is not None and file_format.lower() == "jpg":
180180
file_format = "jpeg"
181181
img = array_to_img(x, data_format=data_format, scale=scale)
182-
if img.mode == "RGBA" and file_format == "jpeg":
182+
if img.mode == "RGBA" and file_format == "jpeg":
183183
warnings.warn(
184184
"The JPEG format does not support RGBA images, converting to RGB."
185185
)

0 commit comments

Comments
 (0)