
Commit 79413bc

Fix shape mismatch in Keras Attention layer during masking (#21595)
* Fix shape mismatch in Keras Attention layer during masking
* Add test case to verify 2D mask shape mismatch
* Update unit test with different Tq, Tv
* Refactor: update test with distinct Tq, Tv, and dim values
1 parent 7da416d commit 79413bc

2 files changed: 19 additions, 0 deletions


keras/src/layers/attention/attention.py

Lines changed: 2 additions & 0 deletions
@@ -176,6 +176,8 @@ def _apply_scores(self, scores, value, scores_mask=None, training=False):
             # Bias so padding positions do not contribute to attention
             # distribution. Note 65504. is the max float16 value.
             max_value = 65504.0 if scores.dtype == "float16" else 1.0e9
+            if len(padding_mask.shape) == 2:
+                padding_mask = ops.expand_dims(padding_mask, axis=-2)
             scores -= max_value * ops.cast(padding_mask, dtype=scores.dtype)

         weights = ops.softmax(scores, axis=-1)
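
Why the one-line guard works: inside `_apply_scores`, `scores` has shape `(batch_size, Tq, Tv)`, but a padding mask derived from a 2D `value_mask` can arrive as `(batch_size, Tv)`, so the `scores -= max_value * ...` subtraction has no `Tq` axis to broadcast against. Expanding the mask on `axis=-2` gives `(batch_size, 1, Tv)`, which broadcasts over every query position. A minimal NumPy sketch of the shape arithmetic (NumPy broadcasting mirrors `keras.ops` here; the concrete shape values are illustrative, not taken from the commit):

import numpy as np

batch_size, Tq, Tv = 2, 3, 4
scores = np.zeros((batch_size, Tq, Tv), dtype="float32")  # attention logits, (batch, Tq, Tv)
padding_mask = np.array([[0, 1, 0, 0], [0, 1, 0, 0]], dtype="float32")  # (batch, Tv)

# Without the fix: (2, 3, 4) - (2, 4) does not broadcast and raises a ValueError.
# With the fix: insert a singleton Tq axis so the mask applies to every query position.
padding_mask = np.expand_dims(padding_mask, axis=-2)       # (batch, 1, Tv)
masked_scores = scores - 65504.0 * padding_mask            # (batch, Tq, Tv)
print(masked_scores.shape)                                 # (2, 3, 4)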

keras/src/layers/attention/attention_test.py

Lines changed: 17 additions & 0 deletions
@@ -86,6 +86,23 @@ def test_attention_with_mask(self):
         self.assertAllClose(output, [[[1.0, 1.0], [0.0, 0.0]]])
         self.assertAllClose(scores, [[[1.0, 0.0], [1.0, 0.0]]])

+    def test_attention_2D_mask_shape_mismatch(self):
+        layer = layers.Attention()
+        batch_size, Tq, Tv, dim = 2, 3, 4, 5
+        query = np.random.random((batch_size, Tq, dim)).astype(np.float32)
+        value = np.random.random((batch_size, Tv, dim)).astype(np.float32)
+        query_mask = np.array([[True, False, True], [True, False, True]])
+        value_mask = np.array(
+            [[True, False, True, True], [True, False, True, True]]
+        )
+        output, scores = layer(
+            [query, value],
+            mask=[query_mask, value_mask],
+            return_attention_scores=True,
+        )
+        self.assertEqual(output.shape, (batch_size, Tq, dim))
+        self.assertEqual(scores.shape, (batch_size, Tq, Tv))
+
     def test_attention_errors(self):
         layer = layers.Attention()
         tensor = np.array([[[1.0, 1.0], [1.0, 1.0]]])
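
For reference, the same scenario outside the test harness: a minimal end-to-end sketch that mirrors `test_attention_2D_mask_shape_mismatch`, assuming a Keras 3 environment where `keras.layers.Attention` accepts `mask=[query_mask, value_mask]` together with `return_attention_scores=True` (exactly what the test above exercises):

import numpy as np
from keras import layers

batch_size, Tq, Tv, dim = 2, 3, 4, 5
layer = layers.Attention()
query = np.random.random((batch_size, Tq, dim)).astype("float32")
value = np.random.random((batch_size, Tv, dim)).astype("float32")
# 2D boolean masks of shape (batch_size, Tq) and (batch_size, Tv).
query_mask = np.array([[True, False, True], [True, False, True]])
value_mask = np.array([[True, False, True, True], [True, False, True, True]])

output, scores = layer(
    [query, value],
    mask=[query_mask, value_mask],
    return_attention_scores=True,
)
print(output.shape)  # (2, 3, 5) -> (batch_size, Tq, dim)
print(scores.shape)  # (2, 3, 4) -> (batch_size, Tq, Tv)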
