Skip to content

Commit e704b28

Browse files
committed
Fix negative index handling in MultiHeadAttention attention_axes
1 parent 2a5bb21 commit e704b28

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import numpy as np
2+
3+
import keras
4+
5+
6+
def test_attention_axes_negative_indexing_matches_positive():
    """Check that a negative `attention_axes` matches its positive twin.

    Two `MultiHeadAttention` layers are created for the same rank-4 input:
    one with `attention_axes=2` and one with `attention_axes=-2`, which
    names the same axis of that input counted from the end. After the
    weights of the first layer are copied into the second, both layers must
    return outputs and attention scores with equal shapes and values.
    """
    # Seeded generator so that a failing run can be reproduced exactly.
    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 3, 8, 4))

    mha_pos = keras.layers.MultiHeadAttention(
        num_heads=2, key_dim=4, attention_axes=2
    )
    mha_neg = keras.layers.MultiHeadAttention(
        num_heads=2, key_dim=4, attention_axes=-2
    )

    # Run each layer once before copying weights; the first call is what
    # creates the layer's weights, so `set_weights` has something to fill.
    mha_pos(x, x)
    mha_neg(x, x)
    mha_neg.set_weights(mha_pos.get_weights())

    z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)
    z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)

    assert z_pos.shape == z_neg.shape
    assert a_pos.shape == a_neg.shape

    # Convert explicitly instead of relying on implicit array conversion,
    # which is not guaranteed to work for every backend's tensor type.
    to_numpy = keras.ops.convert_to_numpy
    np.testing.assert_allclose(
        to_numpy(z_pos), to_numpy(z_neg), rtol=1e-5, atol=1e-5
    )
    np.testing.assert_allclose(
        to_numpy(a_pos), to_numpy(a_neg), rtol=1e-5, atol=1e-5
    )

keras/src/layers/attention/multi_head_attention.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,17 @@ def _build_attention(self, rank):
378378
if self._attention_axes is None:
379379
self._attention_axes = tuple(range(1, rank - 2))
380380
else:
381-
self._attention_axes = tuple(self._attention_axes)
381+
# Normalize negative indices relative to INPUT rank (rank - 1)
382+
input_rank = rank - 1
383+
normalized_axes = []
384+
for ax in self._attention_axes:
385+
if ax < 0:
386+
# Normalize relative to input rank
387+
normalized_ax = input_rank + ax
388+
else:
389+
normalized_ax = ax
390+
normalized_axes.append(normalized_ax)
391+
self._attention_axes = tuple(normalized_axes)
382392
(
383393
self._dot_product_equation,
384394
self._combine_equation,
@@ -760,6 +770,12 @@ def _build_attention_equation(rank, attn_axes):
760770
Returns:
761771
Einsum equations.
762772
"""
773+
# Normalize negative indices to positive indices
774+
if isinstance(attn_axes, (list, tuple)):
775+
attn_axes = tuple(ax % rank if ax < 0 else ax for ax in attn_axes)
776+
else:
777+
attn_axes = (attn_axes % rank if attn_axes < 0 else attn_axes,)
778+
763779
target_notation = ""
764780
for i in range(rank):
765781
target_notation += _index_to_einsum_variable(i)

0 commit comments

Comments
 (0)