@@ -6,8 +6,12 @@
 from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
 from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
 from keras_hub.src.models.smollm3.smollm3_utils import rope_init
-
-
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
 
 class SmolLM3Attention(layers.Layer):
     """
@@ -368,6 +372,46 @@ def __init__(
 
         self.attention_type = layer_types[layer_idx]
 
+
+    def _compute_self_attention_mask(
+        self,
+        decoder_sequence,
+        decoder_padding_mask,
+        decoder_attention_mask,
+        use_causal_mask=True,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+    ):
+        decoder_mask = merge_padding_and_attention_mask(
+            decoder_sequence, decoder_padding_mask, decoder_attention_mask
+        )
+        if use_causal_mask:
+            batch_size = ops.shape(decoder_sequence)[0]
+            input_length = output_length = ops.shape(decoder_sequence)[1]
+            # We need to handle a rectangular causal mask when doing cached
+            # decoding. For generative inference, `decoder_sequence` will
+            # generally be length 1, and `cache` will be the full generation
+            # length.
+            if self_attention_cache is not None:
+                input_length = ops.shape(self_attention_cache)[2]
+
+            causal_mask = compute_causal_mask(
+                batch_size,
+                input_length,
+                output_length,
+                (
+                    0
+                    if self_attention_cache_update_index is None
+                    else self_attention_cache_update_index
+                ),
+            )
+            return (
+                ops.minimum(decoder_mask, causal_mask)
+                if decoder_mask is not None
+                else causal_mask
+            )
+        return decoder_mask
+
     def build(self, input_shape):
         """
         Builds the sub-layers based on the input shape.
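Note on the rectangular causal mask handled in the new `_compute_self_attention_mask` above: during cached decoding the query block is usually a single new token while the keys span the whole cache, so the causal mask becomes output_length x input_length rather than square. The snippet below is a minimal standalone sketch of that idea; it only mirrors the intent of `compute_causal_mask`, not its actual implementation, and `toy_causal_mask` is a hypothetical helper used purely for illustration.

import numpy as np

def toy_causal_mask(batch_size, input_length, output_length, cache_index=0):
    # Query row i sits at absolute position cache_index + i and may attend to
    # every key position j with j <= cache_index + i.
    q = np.arange(output_length)[:, None] + cache_index  # (output_length, 1)
    k = np.arange(input_length)[None, :]                 # (1, input_length)
    mask = (k <= q).astype("int32")
    return np.broadcast_to(mask, (batch_size, output_length, input_length))

# Prompt processing: a square lower-triangular mask.
print(toy_causal_mask(1, 4, 4)[0])
# Cached decoding of one new token at step 4 with a length-8 cache:
print(toy_causal_mask(1, 8, 1, cache_index=4)[0])  # [[1 1 1 1 1 0 0 0]]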
@@ -403,6 +447,8 @@ def call(
         hidden_states,
         position_embeddings=None,
         training=False,
+        decoder_padding_mask=None,
+        decoder_attention_mask=None,
         **kwargs,
     ):
         """
@@ -414,6 +460,15 @@ def call(
             training: Whether the layer is in training mode.
         """
         self_attention_cache = kwargs.get("self_attention_cache", None)
+        self_attention_cache_update_index = kwargs.get("self_attention_cache_update_index", None)
+
+        self_attention_mask = self._compute_self_attention_mask(
+            decoder_sequence=hidden_states,
+            decoder_padding_mask=decoder_padding_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            self_attention_cache=self_attention_cache,
+            self_attention_cache_update_index=self_attention_cache_update_index,
+        )
 
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -423,6 +478,7 @@ def call(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             training=training,
+            attention_mask=self_attention_mask,
             **kwargs,
         )
 
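A brief note on the `ops.minimum(decoder_mask, causal_mask)` combination inside `_compute_self_attention_mask`: since both masks are 0/1 tensors, the element-wise minimum acts as a logical AND, so a key position is kept only if both the padding mask and the causal mask allow it. The sketch below uses plain NumPy with assumed shapes rather than the KerasHub utilities themselves.

import numpy as np

# One sequence of length 4 whose last position is padding.
padding_mask = np.array([[1, 1, 1, 0]])                   # (batch, key_len)
decoder_mask = np.broadcast_to(padding_mask[:, None, :], (1, 4, 4))
causal_mask = np.tril(np.ones((1, 4, 4), dtype="int32"))  # lower triangular

combined = np.minimum(decoder_mask, causal_mask)
print(combined[0])
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]  <- no query may attend to the padded key at index 3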