Fix negative indexing for `attention_axes` in `MultiHeadAttention`.

`_build_attention(rank)` operates on the rank of the projected query
(the input rank plus the appended heads/key-dim axes — consistent with
the existing default `range(1, rank - 2)`), so a negative axis given in
input coordinates is normalized as `(rank - 1) + axis`; non-negative
axes pass through unchanged, keeping backward compatibility. Adds a
regression test asserting `attention_axes=-2` matches `attention_axes=2`
on a rank-4 input (same weights, identical outputs and attention
scores), and ignores local `venv` directories.

NOTE(review): this patch was recovered from a whitespace-collapsed copy;
the `.gitignore` hunk header was corrected to `-6,7 +6,8` to match the
seven context lines actually present in the hunk.

diff --git a/.gitignore b/.gitignore
index afd700b4995..416f213f2c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,7 +6,8 @@ __pycache__
 **/.vscode
 test/**
 **/.vscode-smoke/**
 **/.venv*/
+venv
 bin/**
 build/**
 obj/**
diff --git a/keras/src/layers/attention/multi_head_attention.py b/keras/src/layers/attention/multi_head_attention.py
index a8aa86838d5..4cf70ee2c11 100644
--- a/keras/src/layers/attention/multi_head_attention.py
+++ b/keras/src/layers/attention/multi_head_attention.py
@@ -378,7 +378,10 @@ def _build_attention(self, rank):
         if self._attention_axes is None:
             self._attention_axes = tuple(range(1, rank - 2))
         else:
-            self._attention_axes = tuple(self._attention_axes)
+            self._attention_axes = tuple(
+                axis if axis >= 0 else (rank - 1) + axis
+                for axis in self._attention_axes
+            )
         (
             self._dot_product_equation,
             self._combine_equation,
diff --git a/keras/src/layers/attention/multi_head_attention_test.py b/keras/src/layers/attention/multi_head_attention_test.py
index d74abbd8841..e284635053c 100644
--- a/keras/src/layers/attention/multi_head_attention_test.py
+++ b/keras/src/layers/attention/multi_head_attention_test.py
@@ -203,6 +203,36 @@ def test_high_dim_attention(
             run_training_check=False,
         )
 
+    def test_attention_axes_negative_indexing(self):
+        x = np.random.normal(size=(2, 3, 8, 4))
+
+        # Create two layers with equivalent positive and negative indices
+        mha_pos = layers.MultiHeadAttention(
+            num_heads=2, key_dim=4, attention_axes=2
+        )
+        mha_neg = layers.MultiHeadAttention(
+            num_heads=2, key_dim=4, attention_axes=-2
+        )
+
+        # Initialize both layers
+        _ = mha_pos(x, x)
+        _ = mha_neg(x, x)
+
+        # Set same weights for fair comparison
+        mha_neg.set_weights(mha_pos.get_weights())
+
+        # Get outputs and attention scores
+        z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)
+        z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)
+
+        # Verify shapes match
+        self.assertEqual(z_pos.shape, z_neg.shape)
+        self.assertEqual(a_pos.shape, a_neg.shape)
+
+        # Verify outputs are identical
+        self.assertAllClose(z_pos, z_neg, rtol=1e-5, atol=1e-5)
+        self.assertAllClose(a_pos, a_neg, rtol=1e-5, atol=1e-5)
+
     @parameterized.named_parameters(
         ("without_key_same_proj", (4, 8), (2, 8), None, None),
         ("with_key_same_proj", (4, 8), (2, 8), (2, 3), None),