
Commit df002a9

Fix: Return Attention Scores when return_attention_scores=True (#20684)
* Fix: Ensure Attention Layer Returns Attention Scores when `return_attention_scores=True`

  This pull request addresses an issue in the Attention layer where the
  `return_attention_scores` parameter wasn't correctly handled in the
  `compute_output_shape` method. This fix ensures that attention scores are
  returned when `return_attention_scores=True`.

  ## Changes Made

  Modified the `compute_output_shape` method to return the shapes of both the
  attention output and the attention scores when `return_attention_scores=True`.

* Formatting

* Fixed score return and added unit tests for `return_attention_scores=True`

* Removed debug print statement
1 parent c1316e5 commit df002a9
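
For context, a minimal usage sketch of the behavior this commit fixes (not part of the diff itself): with eager NumPy inputs, passing `return_attention_scores=True` yields an `(output, scores)` tuple, with shapes matching the unit tests added below.

```python
import numpy as np
from keras import layers

# Dummy query/value batches: (batch, Tq, dim) and (batch, Tv, dim).
query = np.random.random((2, 8, 16)).astype("float32")
value = np.random.random((2, 4, 16)).astype("float32")

layer = layers.Attention()

# With the fix, the call returns an (output, scores) tuple.
output, scores = layer([query, value], return_attention_scores=True)
print(output.shape)  # (2, 8, 16) -- (batch, Tq, dim)
print(scores.shape)  # (2, 8, 4)  -- (batch, Tq, Tv)
```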

File tree

2 files changed: +111 -6 lines changed


keras/src/layers/attention/attention.py

Lines changed: 52 additions & 6 deletions
@@ -1,6 +1,7 @@
 from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
+from keras.src.backend import KerasTensor
 from keras.src.layers.layer import Layer


@@ -84,6 +85,8 @@ def __init__(
                 f"Received: score_mode={score_mode}"
             )

+        self._return_attention_scores = False
+
     def build(self, input_shape):
         self._validate_inputs(input_shape)
         self.scale = None
@@ -217,6 +220,7 @@ def call(
         use_causal_mask=False,
     ):
         self._validate_inputs(inputs=inputs, mask=mask)
+        self._return_attention_scores = return_attention_scores
         q = inputs[0]
         v = inputs[1]
         k = inputs[2] if len(inputs) > 2 else v
@@ -226,16 +230,17 @@ def call(
         scores_mask = self._calculate_score_mask(
             scores, v_mask, use_causal_mask
         )
-        result, attention_scores = self._apply_scores(
+        attention_output, attention_scores = self._apply_scores(
             scores=scores, value=v, scores_mask=scores_mask, training=training
         )
         if q_mask is not None:
             # Mask of shape [batch_size, Tq, 1].
             q_mask = ops.expand_dims(q_mask, axis=-1)
-            result *= ops.cast(q_mask, dtype=result.dtype)
+            attention_output *= ops.cast(q_mask, dtype=attention_output.dtype)
         if return_attention_scores:
-            return result, attention_scores
-        return result
+            return (attention_output, attention_scores)
+        else:
+            return attention_output

     def compute_mask(self, inputs, mask=None):
         self._validate_inputs(inputs=inputs, mask=mask)
@@ -244,8 +249,49 @@ def compute_mask(self, inputs, mask=None):
         return ops.convert_to_tensor(mask[0])

     def compute_output_shape(self, input_shape):
-        """Returns shape of value tensor dim, but for query tensor length"""
-        return (*input_shape[0][:-1], input_shape[1][-1])
+        query_shape, value_shape, key_shape = input_shape
+        if key_shape is None:
+            key_shape = value_shape
+
+        output_shape = (*query_shape[:-1], value_shape[-1])
+        if self._return_attention_scores:
+            scores_shape = (query_shape[0], query_shape[1], key_shape[1])
+            return output_shape, scores_shape
+        return output_shape
+
+    def compute_output_spec(
+        self,
+        inputs,
+        mask=None,
+        return_attention_scores=False,
+        training=None,
+        use_causal_mask=False,
+    ):
+        # Validate and unpack inputs
+        self._validate_inputs(inputs, mask)
+        query = inputs[0]
+        value = inputs[1]
+        key = inputs[2] if len(inputs) > 2 else value
+
+        # Compute primary output shape
+        output_shape = self.compute_output_shape(
+            [query.shape, value.shape, key.shape]
+        )
+        output_spec = KerasTensor(output_shape, dtype=self.compute_dtype)
+
+        # Handle attention scores if requested
+        if self._return_attention_scores:
+            scores_shape = (
+                query.shape[0],
+                query.shape[1],
+                key.shape[1],
+            )  # (batch_size, Tq, Tv)
+            attention_scores_spec = KerasTensor(
+                scores_shape, dtype=self.compute_dtype
+            )
+            return (output_spec, attention_scores_spec)
+
+        return output_spec

     def _validate_inputs(self, inputs, mask=None):
         """Validates arguments of the call method."""

keras/src/layers/attention/attention_test.py

Lines changed: 59 additions & 0 deletions
@@ -358,3 +358,62 @@ def test_attention_compute_output_shape(self):
             ),
             output.shape,
         )
+
+    def test_return_attention_scores_true(self):
+        """Test that the layer returns attention scores along with outputs."""
+        # Generate dummy input data
+        query = np.random.random((2, 8, 16)).astype(np.float32)
+        value = np.random.random((2, 4, 16)).astype(np.float32)
+
+        # Initialize the Attention layer
+        layer = layers.Attention()
+
+        # Call the layer with return_attention_scores=True
+        output, attention_scores = layer(
+            [query, value], return_attention_scores=True
+        )
+
+        # Check the shape of the outputs
+        self.assertEqual(output.shape, (2, 8, 16))  # Output shape
+        self.assertEqual(
+            attention_scores.shape, (2, 8, 4)
+        )  # Attention scores shape
+
+    def test_return_attention_scores_true_and_tuple(self):
+        """Test that the layer outputs are a tuple when
+        return_attention_scores=True."""
+        # Generate dummy input data
+        query = np.random.random((2, 8, 16)).astype(np.float32)
+        value = np.random.random((2, 4, 16)).astype(np.float32)
+
+        # Initialize the Attention layer
+        layer = layers.Attention()
+
+        # Call the layer with return_attention_scores=True
+        outputs = layer([query, value], return_attention_scores=True)
+
+        # Check that outputs is a tuple
+        self.assertIsInstance(
+            outputs, tuple, "Expected the outputs to be a tuple"
+        )
+
+    def test_return_attention_scores_true_tuple_then_unpack(self):
+        """Test that outputs can be unpacked correctly."""
+        # Generate dummy input data
+        query = np.random.random((2, 8, 16)).astype(np.float32)
+        value = np.random.random((2, 4, 16)).astype(np.float32)
+
+        # Initialize the Attention layer
+        layer = layers.Attention()
+
+        # Call the layer with return_attention_scores=True
+        outputs = layer([query, value], return_attention_scores=True)
+
+        # Unpack the outputs
+        output, attention_scores = outputs
+
+        # Check the shape of the unpacked outputs
+        self.assertEqual(output.shape, (2, 8, 16))  # Output shape
+        self.assertEqual(
+            attention_scores.shape, (2, 8, 4)
+        )  # Attention scores shape
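
To exercise just the new tests from a Keras source checkout (assuming pytest is installed; invoking it via `pytest.main` is an illustrative choice, not part of this commit):

```python
# Run only the tests added in this commit; the -k filter matches their names.
import pytest

pytest.main([
    "keras/src/layers/attention/attention_test.py",
    "-k", "return_attention_scores",
])
```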
