Skip to content

Commit 8907bcb

Browse files
authored
fix attention output with symbolic tensors and attention scores (#20689)
1 parent f54c127 commit 8907bcb

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

keras/src/layers/attention/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def compute_output_spec(
280280
output_spec = KerasTensor(output_shape, dtype=self.compute_dtype)
281281

282282
# Handle attention scores if requested
283-
if self._return_attention_scores:
283+
if self._return_attention_scores or return_attention_scores:
284284
scores_shape = (
285285
query.shape[0],
286286
query.shape[1],

keras/src/layers/attention/attention_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,15 @@ def test_return_attention_scores_true_tuple_then_unpack(self):
417417
self.assertEqual(
418418
attention_scores.shape, (2, 8, 4)
419419
) # Attention scores shape
420+
421+
def test_return_attention_scores_with_symbolic_tensors(self):
422+
"""Test to check outputs with symbolic tensors with
423+
return_attention_scores = True"""
424+
attention = layers.Attention()
425+
x = layers.Input(shape=(3, 5))
426+
y = layers.Input(shape=(4, 5))
427+
output, attention_scores = attention(
428+
[x, y], return_attention_scores=True
429+
)
430+
self.assertEqual(output.shape, (None, 3, 5)) # Output shape
431+
self.assertEqual(attention_scores.shape, (None, 3, 4))

0 commit comments

Comments
 (0)