-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Fix negative index handling in MultiHeadAttention attention_axes #21721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 3 commits
2b65f89
2a5bb21
deebbc6
17731f1
2588cd1
a05c61c
23f2a18
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ __pycache__ | |
**/.vscode test/** | ||
**/.vscode-smoke/** | ||
**/.venv*/ | ||
venv | ||
bin/** | ||
build/** | ||
obj/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import numpy as np | ||
|
||
import keras | ||
|
||
|
||
def test_attention_axes_negative_indexing_matches_positive():
    """A negative `attention_axes` index must behave exactly like its
    positive counterpart: for a rank-4 input, axis -2 and axis 2 refer to
    the same dimension, so both layers must produce identical outputs and
    attention scores once they share weights.
    """
    inputs = np.random.normal(size=(2, 3, 8, 4))

    layer_positive = keras.layers.MultiHeadAttention(
        num_heads=2, key_dim=4, attention_axes=2
    )
    layer_negative = keras.layers.MultiHeadAttention(
        num_heads=2, key_dim=4, attention_axes=-2
    )

    # Call each layer once so its variables are created, then copy the
    # positive layer's weights into the negative one so any output
    # difference can only come from axis handling.
    layer_positive(inputs, inputs)
    layer_negative(inputs, inputs)
    layer_negative.set_weights(layer_positive.get_weights())

    out_pos, scores_pos = layer_positive(
        inputs, inputs, return_attention_scores=True
    )
    out_neg, scores_neg = layer_negative(
        inputs, inputs, return_attention_scores=True
    )

    assert out_pos.shape == out_neg.shape
    assert scores_pos.shape == scores_neg.shape

    np.testing.assert_allclose(out_pos, out_neg, rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(scores_pos, scores_neg, rtol=1e-5, atol=1e-5)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import os | ||
|
||
|
||
import numpy as np | ||
import pytest | ||
|
||
from keras.utils import img_to_array | ||
from keras.utils import load_img | ||
from keras.utils import save_img | ||
|
||
|
||
@pytest.mark.parametrize(
    "shape, name",
    [
        ((50, 50, 3), "rgb.jpg"),
        ((50, 50, 4), "rgba.jpg"),
    ],
)
def test_save_jpg(tmp_path, shape, name):
    """`save_img` with file_format="jpg" must write a valid JPEG file.

    JPEG has no alpha channel, so an RGBA input is expected to be saved
    as RGB; reloading must therefore always yield a 3-channel image.
    """
    pixels = np.random.randint(0, 256, size=shape, dtype=np.uint8)
    target = tmp_path / name

    save_img(target, pixels, file_format="jpg")
    assert os.path.exists(target)

    # Round-trip the file: regardless of whether the source had an alpha
    # channel, the decoded JPEG must come back as (50, 50, 3).
    round_trip = img_to_array(load_img(target))
    assert round_trip.shape == (50, 50, 3)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -378,7 +378,17 @@ def _build_attention(self, rank): | |
if self._attention_axes is None: | ||
self._attention_axes = tuple(range(1, rank - 2)) | ||
else: | ||
self._attention_axes = tuple(self._attention_axes) | ||
# Normalize negative indices relative to INPUT rank (rank - 1) | ||
input_rank = rank - 1 | ||
|
||
normalized_axes = [] | ||
for ax in self._attention_axes: | ||
if ax < 0: | ||
# Normalize relative to input rank | ||
normalized_ax = input_rank + ax | ||
else: | ||
normalized_ax = ax | ||
normalized_axes.append(normalized_ax) | ||
self._attention_axes = tuple(normalized_axes) | ||
|
||
( | ||
self._dot_product_equation, | ||
self._combine_equation, | ||
|
@@ -760,6 +770,12 @@ def _build_attention_equation(rank, attn_axes): | |
Returns: | ||
Einsum equations. | ||
""" | ||
# Normalize negative indices to positive indices | ||
if isinstance(attn_axes, (list, tuple)): | ||
attn_axes = tuple(ax % rank if ax < 0 else ax for ax in attn_axes) | ||
else: | ||
attn_axes = (attn_axes % rank if attn_axes < 0 else attn_axes,) | ||
|
||
|
||
target_notation = "" | ||
for i in range(rank): | ||
target_notation += _index_to_einsum_variable(i) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -175,10 +175,13 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs): | |
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`. | ||
""" | ||
data_format = backend.standardize_data_format(data_format) | ||
# Normalize jpg → jpeg | ||
if file_format is not None and file_format.lower() == "jpg": | ||
|
||
file_format = "jpeg" | ||
img = array_to_img(x, data_format=data_format, scale=scale) | ||
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"): | ||
if img.mode == "RGBA" and file_format == "jpeg": | ||
warnings.warn( | ||
"The JPG format does not support RGBA images, converting to RGB." | ||
"The JPEG format does not support RGBA images, converting to RGB." | ||
) | ||
img = img.convert("RGB") | ||
img.save(path, format=file_format, **kwargs) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this test into `multi_head_attention_test.py` and use the unit-test style of that file, i.e. `self.assertEqual`, `self.assertAllClose`, etc.