from keras_hub.src.models.smollm3.smollm3_utils import rope_init


+
class SmolLM3Attention(layers.Layer):
    """
    Multi-head attention layer for SmolLM3 model.
@@ -94,14 +95,14 @@ def build(self, input_shape):
        self.k_proj.build(hidden_states_shape)
        self.v_proj.build(hidden_states_shape)
        self.o_proj.build(hidden_states_shape)
-        self.training = False
        super().build(input_shape)

    def call(
        self,
        hidden_states,
        position_embeddings,
        training=False,
+        attention_mask=None,
        **kwargs,
    ):
        """
@@ -142,19 +143,14 @@ def _compute_kv_values(x_input):
            value_states = ops.transpose(value_states_raw, axes=(0, 2, 1, 3))
            return key_states, value_states

-        print("self_attention_cache is ", self_attention_cache)
        if self_attention_cache is not None:
            key_cache = self_attention_cache[:, 0, ...]
            value_cache = self_attention_cache[:, 1, ...]

            if self_attention_cache_update_index is None:
-                print("self_attention_cache_update_index is None")
                key_states = key_cache
                value_states = value_cache
            else:
-                print(
-                    "self_attention_cache_update_index is not None, computing kv values"
-                )
                key_update, value_update = _compute_kv_values(hidden_states)
                update_idx_tensor = ops.convert_to_tensor(
                    self_attention_cache_update_index, dtype="int32"
@@ -190,6 +186,7 @@ def _compute_kv_values(x_input):
            dropout=self.attention_dropout,
            scaling=self.scaling,
            training=self.training,
+            attention_mask=attention_mask,
        )

        attn_output = ops.reshape(attn_output, (*input_shape, self.hidden_size))
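Note for reviewers: a minimal sketch of how the new attention_mask argument could be exercised from outside the layer. The mask construction uses only keras.ops; the layer construction and the exact (cos, sin) format of position_embeddings are assumptions based on the surrounding code, not confirmed by this hunk, so that part is left commented out.

import numpy as np
from keras import ops

batch_size, seq_len, hidden_size, head_dim = 2, 16, 64, 16  # hypothetical sizes

hidden_states = ops.convert_to_tensor(
    np.random.rand(batch_size, seq_len, hidden_size).astype("float32")
)
# Assumed (cos, sin) pair as produced by the rotary embedding layer.
cos = ops.ones((batch_size, seq_len, head_dim))
sin = ops.zeros((batch_size, seq_len, head_dim))

# Boolean causal mask, broadcastable over heads; a padding mask could be
# combined with it the same way before being passed in.
attention_mask = ops.tril(ops.ones((batch_size, 1, seq_len, seq_len), dtype="bool"))

# attn = SmolLM3Attention(...)  # constructor arguments are not shown in this diff
# out = attn(
#     hidden_states,
#     position_embeddings=(cos, sin),
#     attention_mask=attention_mask,
# )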
@@ -277,7 +274,6 @@ def build(self, input_shape):
            self.intermediate_size,
        )
        self.down_proj.build(down_proj_input_shape)
-        self.training = False
        super().build(input_shape)

    def call(self, x):
@@ -322,7 +318,7 @@ class SmolLM3DecoderLayer(layers.Layer):
        layer_idx: Index of the current layer.
        intermediate_size: The intermediate size of the MLP.
        mlp_bias: Whether to use bias in MLP dense layers.
-        rms_norm_epsilon: Epsilon for RMSNormalization.
+        layer_norm_epsilon: Epsilon for RMSNormalization.
    """

    def __init__(
@@ -337,7 +333,7 @@ def __init__(
        layer_idx: int,
        intermediate_size: int,
        mlp_bias: bool,
-        rms_norm_epsilon: float,
+        layer_norm_epsilon: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
@@ -364,10 +360,10 @@ def __init__(
        )

        self.input_layernorm = layers.RMSNormalization(
-            epsilon=rms_norm_epsilon, axis=-1, name="input_layernorm"
+            epsilon=layer_norm_epsilon, axis=-1, name="input_layernorm"
        )
        self.post_attention_layernorm = layers.RMSNormalization(
-            epsilon=rms_norm_epsilon, axis=-1, name="post_attention_layernorm"
+            epsilon=layer_norm_epsilon, axis=-1, name="post_attention_layernorm"
        )

        self.attention_type = layer_types[layer_idx]
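Side note on the rename: only the constructor keyword changes; the epsilon still feeds the two layers.RMSNormalization instances built above. A small sketch of that call pattern follows, with a hypothetical SmolLM3DecoderLayer construction left commented out because the full constructor signature is not visible in this diff.

import numpy as np
from keras import layers, ops

x = ops.convert_to_tensor(np.random.rand(2, 4, 8).astype("float32"))

# Same call pattern as input_layernorm / post_attention_layernorm above.
norm = layers.RMSNormalization(epsilon=1e-6, axis=-1, name="input_layernorm")
y = norm(x)  # x scaled by 1 / sqrt(mean(x**2, axis=-1) + epsilon), times a learned scale
print(ops.shape(y))  # (2, 4, 8)

# Callers would now pass the renamed keyword (the other arguments here are assumed):
# decoder = SmolLM3DecoderLayer(
#     ...,
#     intermediate_size=256,
#     mlp_bias=False,
#     layer_norm_epsilon=1e-6,  # formerly rms_norm_epsilon
# )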
@@ -399,7 +395,6 @@ def build(self, input_shape):
        self.mlp.build(input_shape)
        self.input_layernorm.build(input_shape)
        self.post_attention_layernorm.build(input_shape)
-        self.training = False

        super().build(input_shape)

@@ -518,7 +513,6 @@ def build(self, input_shape):
            - position_ids_shape: (batch_size, seq_len)
        """
        # No internal layers to explicitly build here, as inv_freq is added in __init__
-        self.training = False
        super().build(input_shape)

    def call(