Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ __pycache__
**/.vscode test/**
**/.vscode-smoke/**
**/.venv*/
venv
bin/**
build/**
obj/**
Expand Down
28 changes: 28 additions & 0 deletions integration_tests/test_multi_head_attention_negative_axis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import numpy as np

import keras


def test_attention_axes_negative_indexing_matches_positive():
x = np.random.normal(size=(2, 3, 8, 4))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move to multi_head_attention_test.py and use the unit test style, i.e. self.assertEqual, self.assertAllClose, ...


mha_pos = keras.layers.MultiHeadAttention(
num_heads=2, key_dim=4, attention_axes=2
)
mha_neg = keras.layers.MultiHeadAttention(
num_heads=2, key_dim=4, attention_axes=-2
)

_ = mha_pos(x, x)
_ = mha_neg(x, x)

mha_neg.set_weights(mha_pos.get_weights())

z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)
z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)

assert z_pos.shape == z_neg.shape
assert a_pos.shape == a_neg.shape

np.testing.assert_allclose(z_pos, z_neg, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(a_pos, a_neg, rtol=1e-5, atol=1e-5)
27 changes: 27 additions & 0 deletions integration_tests/test_save_img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated changes, please revert.


import numpy as np
import pytest

from keras.utils import img_to_array
from keras.utils import load_img
from keras.utils import save_img


@pytest.mark.parametrize(
"shape, name",
[
((50, 50, 3), "rgb.jpg"),
((50, 50, 4), "rgba.jpg"),
],
)
def test_save_jpg(tmp_path, shape, name):
img = np.random.randint(0, 256, size=shape, dtype=np.uint8)
path = tmp_path / name
save_img(path, img, file_format="jpg")
assert os.path.exists(path)

# Check that the image was saved correctly and converted to RGB if needed.
loaded_img = load_img(path)
loaded_array = img_to_array(loaded_img)
assert loaded_array.shape == (50, 50, 3)
18 changes: 17 additions & 1 deletion keras/src/layers/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,17 @@ def _build_attention(self, rank):
if self._attention_axes is None:
self._attention_axes = tuple(range(1, rank - 2))
else:
self._attention_axes = tuple(self._attention_axes)
# Normalize negative indices relative to INPUT rank (rank - 1)
input_rank = rank - 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why rank - 1?

I think this would be enough instead of lines 381-391:

self._attention_axes = tuple(axis + rank if axis < 0 else axis for axis in self._attention_axes)

normalized_axes = []
for ax in self._attention_axes:
if ax < 0:
# Normalize relative to input rank
normalized_ax = input_rank + ax
else:
normalized_ax = ax
normalized_axes.append(normalized_ax)
self._attention_axes = tuple(normalized_axes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better readability and conciseness, this loop for normalizing axes can be simplified into a single list comprehension.

            input_rank = rank - 1
            self._attention_axes = tuple(
                input_rank + ax if ax < 0 else ax for ax in self._attention_axes
            )

(
self._dot_product_equation,
self._combine_equation,
Expand Down Expand Up @@ -760,6 +770,12 @@ def _build_attention_equation(rank, attn_axes):
Returns:
Einsum equations.
"""
# Normalize negative indices to positive indices
if isinstance(attn_axes, (list, tuple)):
attn_axes = tuple(ax % rank if ax < 0 else ax for ax in attn_axes)
else:
attn_axes = (attn_axes % rank if attn_axes < 0 else attn_axes,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block for normalizing negative indices appears to be both redundant and incorrect.

  1. Redundant: The _build_attention method already normalizes self._attention_axes to positive indices before passing them to this function. This block will therefore have no effect on the already-positive axes.
  2. Incorrect: If this block were to handle negative indices, its logic ax % rank is incorrect. It normalizes based on the projected tensor's rank, which is the exact bug this PR aims to fix. The correct normalization should be relative to the input rank, as correctly implemented in _build_attention.

To avoid confusion and potential future bugs, this block should be removed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, please revert.


target_notation = ""
for i in range(rank):
target_notation += _index_to_einsum_variable(i)
Expand Down
7 changes: 5 additions & 2 deletions keras/src/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,13 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
"""
data_format = backend.standardize_data_format(data_format)
# Normalize jpg → jpeg
if file_format is not None and file_format.lower() == "jpg":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated changes, please revert.

file_format = "jpeg"
img = array_to_img(x, data_format=data_format, scale=scale)
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
if img.mode == "RGBA" and file_format == "jpeg":
warnings.warn(
"The JPG format does not support RGBA images, converting to RGB."
"The JPEG format does not support RGBA images, converting to RGB."
)
img = img.convert("RGB")
img.save(path, format=file_format, **kwargs)
Expand Down