Skip to content

Commit 7481c8e

Browse files
committed
fix self attn mask comp
1 parent f0c412f commit 7481c8e

File tree

1 file changed

+23
-27
lines changed

1 file changed

+23
-27
lines changed

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
def _compute_self_attention_mask(
    self,  # NOTE(review): `def` line sits in the diff hunk header; `self` assumed from the enclosing layer class — confirm against the full file.
    decoder_sequence,
    decoder_padding_mask,
    decoder_attention_mask,
    self_attention_cache,
    self_attention_cache_update_index,
):
    """Build the combined padding + causal mask for self-attention.

    Merges any user-supplied padding/attention masks with a causal mask
    sized for the current decoding step, then returns their element-wise
    minimum (intersection). Always applies the causal mask.
    """
    merged_mask = merge_padding_and_attention_mask(
        decoder_sequence, decoder_padding_mask, decoder_attention_mask
    )

    batch_size = ops.shape(decoder_sequence)[0]
    output_length = ops.shape(decoder_sequence)[1]
    input_length = output_length
    # Cached decoding needs a rectangular causal mask: during generative
    # inference `decoder_sequence` is typically length 1 while the cache
    # spans the full generation length, so key length comes from the cache.
    if self_attention_cache is not None:
        input_length = ops.shape(self_attention_cache)[2]

    # A missing update index means we are at the start of the sequence.
    if self_attention_cache_update_index is None:
        cache_update_index = 0
    else:
        cache_update_index = self_attention_cache_update_index

    causal_mask = compute_causal_mask(
        batch_size, input_length, output_length, cache_update_index
    )

    # Intersect with the padding/attention mask when one was provided.
    if merged_mask is None:
        return causal_mask
    return ops.minimum(merged_mask, causal_mask)
414410

415411
def build(self, input_shape):
416412
"""

0 commit comments

Comments (0)