
Commit 671c261

actually pass attn mask
1 parent 00c025a commit 671c261

2 files changed: 10 additions, 9 deletions


keras_hub/src/models/smollm3/smollm3_backbone.py

1 addition, 2 deletions

@@ -128,8 +128,7 @@ def __init__(
             hidden_states = decoder_layer(
                 hidden_states,
                 position_embeddings=position_embeddings,
-                decoder_padding_mask=padding_mask_input
-                **kwargs,
+                decoder_padding_mask=padding_mask_input,
             )
 
         sequence_output = self.norm(hidden_states)
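The backbone change is the fix the commit title refers to: in the removed lines there is no comma after `padding_mask_input`, so Python parses the two lines as a single keyword argument, `decoder_padding_mask=padding_mask_input ** kwargs` (an exponentiation), and the padding mask was never actually passed as its own argument. A small standalone demo of that pitfall (the `show` helper below is a hypothetical stand-in for the decoder layer call, not repository code):

```python
# Minimal demo of the pitfall fixed by this commit; `show` stands in for the
# decoder layer call and is not code from keras_hub.
def show(*args, **kwargs):
    print("args:", args, "kwargs:", kwargs)

mask = 2      # stands in for `padding_mask_input`
extra = {}    # stands in for a forwarded `**kwargs` dict

# Buggy spelling: with no comma, `mask ** extra` is parsed as exponentiation,
# so the call tries to evaluate 2 ** {} and raises TypeError at run time.
try:
    show(decoder_padding_mask=mask ** extra)
except TypeError as err:
    print("TypeError:", err)

# Fixed spelling: the mask is passed as its own keyword argument.
show(decoder_padding_mask=mask)
```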

keras_hub/src/models/smollm3/smollm3_layers.py

9 additions, 7 deletions

@@ -3,15 +3,16 @@
 from keras import layers
 from keras import ops
 
-from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
-from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
-from keras_hub.src.models.smollm3.smollm3_utils import rope_init
 from keras_hub.src.layers.modeling.transformer_layer_utils import (
-    merge_padding_and_attention_mask,
+    compute_causal_mask,
 )
 from keras_hub.src.layers.modeling.transformer_layer_utils import (
-    compute_causal_mask,
+    merge_padding_and_attention_mask,
 )
+from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
+from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
+from keras_hub.src.models.smollm3.smollm3_utils import rope_init
+
 
 class SmolLM3Attention(layers.Layer):
     """
@@ -372,7 +373,6 @@ def __init__(
 
         self.attention_type = layer_types[layer_idx]
 
-
     def _compute_self_attention_mask(
         self,
         decoder_sequence,
@@ -460,7 +460,9 @@ def call(
            training: Whether the layer is in training mode.
         """
         self_attention_cache = kwargs.get("self_attention_cache", None)
-        self_attention_cache_update_index = kwargs.get("self_attention_cache_update_index", None)
+        self_attention_cache_update_index = kwargs.get(
+            "self_attention_cache_update_index", None
+        )
 
         self_attention_mask = self._compute_self_attention_mask(
             decoder_sequence=hidden_states,
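For context on the reordered imports: `compute_causal_mask` and `merge_padding_and_attention_mask` are keras_hub's standard decoder masking utilities, and `_compute_self_attention_mask` (touched in the second hunk) is where they are typically combined with the cache fields read via `kwargs.get` in the third hunk. The sketch below shows that common pattern, modeled on keras_hub's generic `TransformerDecoder` rather than copied from the SmolLM3 layer, which may differ in detail:

```python
from keras import ops

from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
)
from keras_hub.src.layers.modeling.transformer_layer_utils import (
    merge_padding_and_attention_mask,
)


def compute_self_attention_mask(
    decoder_sequence,
    decoder_padding_mask,
    decoder_attention_mask=None,
    self_attention_cache=None,
    self_attention_cache_update_index=None,
):
    # Merge the per-token padding mask with any user-provided attention mask.
    decoder_mask = merge_padding_and_attention_mask(
        decoder_sequence, decoder_padding_mask, decoder_attention_mask
    )
    batch_size = ops.shape(decoder_sequence)[0]
    input_length = output_length = ops.shape(decoder_sequence)[1]
    # With a key/value cache the causal mask is rectangular: queries cover only
    # the new tokens, keys cover the full cached sequence length.
    if self_attention_cache is not None:
        input_length = ops.shape(self_attention_cache)[2]
    cache_update_index = (
        0
        if self_attention_cache_update_index is None
        else self_attention_cache_update_index
    )
    causal_mask = compute_causal_mask(
        batch_size, input_length, output_length, cache_update_index
    )
    # Element-wise minimum keeps a position only if both masks allow it.
    return (
        ops.minimum(decoder_mask, causal_mask)
        if decoder_mask is not None
        else causal_mask
    )
```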
