
Commit 00c025a: actually pass attn mask
1 parent c45f589

3 files changed: 62 additions and 4 deletions

keras_hub/src/models/smollm3/smollm3_backbone.py
Lines changed: 4 additions & 0 deletions

@@ -115,6 +115,9 @@ def __init__(
         position_id_input = keras.Input(
             shape=(None,), dtype="int32", name="position_ids"
         )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )

         hidden_states = self.token_embedding(token_id_input)
         position_embeddings = self.rotary_embedding(
@@ -125,6 +128,7 @@ def __init__(
             hidden_states = decoder_layer(
                 hidden_states,
                 position_embeddings=position_embeddings,
+                decoder_padding_mask=padding_mask_input,
                 **kwargs,
             )
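For context on how this new input is consumed, here is a minimal, self-contained sketch of the same functional-input pattern: a named padding_mask placeholder wired into a keras.Model and fed by key at call time. The tiny embedding model, vocabulary size, and data are made up for illustration; this is not the SmolLM3Backbone API.

import keras
from keras import layers, ops

token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")

x = layers.Embedding(input_dim=100, output_dim=8)(token_ids)
# Zero out embeddings at padded positions, standing in for the way the
# backbone now forwards `padding_mask` into every decoder layer.
x = x * ops.cast(ops.expand_dims(padding_mask, axis=-1), x.dtype)

toy_model = keras.Model(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=x,
)
outputs = toy_model(
    {
        "token_ids": ops.ones((2, 5), dtype="int32"),
        "padding_mask": ops.convert_to_tensor(
            [[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]], dtype="int32"
        ),
    }
)
print(outputs.shape)  # (2, 5, 8)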

keras_hub/src/models/smollm3/smollm3_layers.py
Lines changed: 58 additions & 2 deletions

@@ -6,8 +6,12 @@
 from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
 from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
 from keras_hub.src.models.smollm3.smollm3_utils import rope_init
-
-
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)

 class SmolLM3Attention(layers.Layer):
     """
@@ -368,6 +372,46 @@ def __init__(

         self.attention_type = layer_types[layer_idx]

+
+    def _compute_self_attention_mask(
+        self,
+        decoder_sequence,
+        decoder_padding_mask,
+        decoder_attention_mask,
+        use_causal_mask=True,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+    ):
+        decoder_mask = merge_padding_and_attention_mask(
+            decoder_sequence, decoder_padding_mask, decoder_attention_mask
+        )
+        if use_causal_mask:
+            batch_size = ops.shape(decoder_sequence)[0]
+            input_length = output_length = ops.shape(decoder_sequence)[1]
+            # We need to handle a rectangular causal mask when doing cached
+            # decoding. For generative inference, `decoder_sequence` will
+            # generally be length 1, and `cache` will be the full generation
+            # length.
+            if self_attention_cache is not None:
+                input_length = ops.shape(self_attention_cache)[2]
+
+            causal_mask = compute_causal_mask(
+                batch_size,
+                input_length,
+                output_length,
+                (
+                    0
+                    if self_attention_cache_update_index is None
+                    else self_attention_cache_update_index
+                ),
+            )
+            return (
+                ops.minimum(decoder_mask, causal_mask)
+                if decoder_mask is not None
+                else causal_mask
+            )
+        return decoder_mask
+
     def build(self, input_shape):
         """
         Builds the sub-layers based on the input shape.
@@ -403,6 +447,8 @@ def call(
         hidden_states,
         position_embeddings=None,
         training=False,
+        decoder_padding_mask=None,
+        decoder_attention_mask=None,
         **kwargs,
     ):
         """
@@ -414,6 +460,15 @@
             training: Whether the layer is in training mode.
         """
         self_attention_cache = kwargs.get("self_attention_cache", None)
+        self_attention_cache_update_index = kwargs.get("self_attention_cache_update_index", None)
+
+        self_attention_mask = self._compute_self_attention_mask(
+            decoder_sequence=hidden_states,
+            decoder_padding_mask=decoder_padding_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            self_attention_cache=self_attention_cache,
+            self_attention_cache_update_index=self_attention_cache_update_index,
+        )

         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -423,6 +478,7 @@
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             training=training,
+            attention_mask=self_attention_mask,
             **kwargs,
         )
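As a quick illustration of what _compute_self_attention_mask builds, the sketch below merges a padding mask and a causal mask the same way the new method does, using the two utilities imported above. The shapes and data are toy values, not SmolLM3 layers: row i of the result lets position i attend to positions 0 through i, and the padded final position is excluded from every row.

from keras import ops
from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
    merge_padding_and_attention_mask,
)

batch_size, seq_len, hidden_dim = 1, 4, 8
hidden_states = ops.zeros((batch_size, seq_len, hidden_dim))
# Mark the last position as padding.
padding_mask = ops.convert_to_tensor([[1, 1, 1, 0]], dtype="int32")

decoder_mask = merge_padding_and_attention_mask(
    hidden_states, padding_mask, None
)  # shape (1, 1, 4): which keys are real tokens
causal_mask = compute_causal_mask(batch_size, seq_len, seq_len, 0)  # (1, 4, 4)

mask = ops.minimum(decoder_mask, causal_mask)
print(ops.convert_to_numpy(ops.cast(mask, "int32"))[0])
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]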

keras_hub/src/models/smollm3/smollm3_utils.py
Lines changed: 0 additions & 2 deletions

@@ -46,8 +46,6 @@ def eager_attention_forward(
         * scaling
     )

-    # Apply attention mask if provided
-    print("attention_mask", attention_mask)
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]]
         attn_weights = ops.add(attn_weights, causal_mask)
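For reference, a generic sketch of the additive masking step that eager_attention_forward applies to its attention_mask argument. This is not the SmolLM3 implementation; it assumes the mask has already been converted to additive form (0.0 where attention is allowed, a large negative value where it is not), so that adding it to the raw scores suppresses masked positions after the softmax.

from keras import ops

batch, heads, q_len, k_len, head_dim = 1, 2, 3, 3, 4
query = ops.ones((batch, heads, q_len, head_dim))
key_states = ops.ones((batch, heads, k_len, head_dim))
scaling = head_dim**-0.5

attn_weights = ops.matmul(
    query, ops.transpose(key_states, (0, 1, 3, 2))
) * scaling

# Build an additive causal mask of shape (1, 1, q_len, k_len):
# 0.0 on and below the diagonal, -1e9 above it.
keep = ops.tril(ops.ones((q_len, k_len)))
additive_mask = ops.expand_dims(ops.expand_dims((1.0 - keep) * -1e9, 0), 0)

# Mirror the slicing in the diff: trim the mask to the key length.
causal_mask = additive_mask[:, :, :, : ops.shape(key_states)[-2]]
attn_weights = ops.add(attn_weights, causal_mask)
probs = ops.softmax(attn_weights, axis=-1)
print(ops.convert_to_numpy(probs)[0, 0])
# Row 0 attends only to key 0, row 1 splits over keys 0-1, row 2 over keys 0-2.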
