Commit 5be137b

pass attention mask
1 parent 8fbabf8 commit 5be137b

3 files changed: +8 -15 lines changed

keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def __init__(
                 layer_idx=i,
                 intermediate_size=intermediate_dim,
                 mlp_bias=mlp_bias,
-                rms_norm_epsilon=layer_norm_epsilon,
+                layer_norm_epsilon=layer_norm_epsilon,
                 name=f"transformer_layer_{i}",
             )
             self.transformer_layers.append(layer)

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 7 additions & 13 deletions
@@ -8,6 +8,7 @@
 from keras_hub.src.models.smollm3.smollm3_utils import rope_init
 
 
+
 class SmolLM3Attention(layers.Layer):
     """
     Multi-head attention layer for SmolLM3 model.
@@ -94,14 +95,14 @@ def build(self, input_shape):
         self.k_proj.build(hidden_states_shape)
         self.v_proj.build(hidden_states_shape)
         self.o_proj.build(hidden_states_shape)
-        self.training = False
         super().build(input_shape)
 
     def call(
         self,
         hidden_states,
         position_embeddings,
         training=False,
+        attention_mask=None,
         **kwargs,
     ):
         """
@@ -142,19 +143,14 @@ def _compute_kv_values(x_input):
             value_states = ops.transpose(value_states_raw, axes=(0, 2, 1, 3))
             return key_states, value_states
 
-        print("self_attention_cache is ", self_attention_cache)
         if self_attention_cache is not None:
             key_cache = self_attention_cache[:, 0, ...]
             value_cache = self_attention_cache[:, 1, ...]
 
             if self_attention_cache_update_index is None:
-                print("self_attention_cache_update_index is None")
                 key_states = key_cache
                 value_states = value_cache
             else:
-                print(
-                    "self_attention_cache_update_index is not None, computing kv values"
-                )
                 key_update, value_update = _compute_kv_values(hidden_states)
                 update_idx_tensor = ops.convert_to_tensor(
                     self_attention_cache_update_index, dtype="int32"
@@ -190,6 +186,7 @@ def _compute_kv_values(x_input):
             dropout=self.attention_dropout,
             scaling=self.scaling,
             training=self.training,
+            attention_mask=attention_mask,
         )
 
         attn_output = ops.reshape(attn_output, (*input_shape, self.hidden_size))
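
Note: with this change, SmolLM3Attention.call accepts an attention_mask and forwards it into the attention computation shown above. The diff does not show which mask convention SmolLM3 expects, so the caller-side sketch below is hypothetical: it converts a boolean padding mask into the additive form many eager attention implementations add to the scores.

import numpy as np
from keras import ops

# Hypothetical caller-side sketch. The diff only shows that attention_mask is
# now accepted and forwarded; the exact convention SmolLM3 expects (boolean vs.
# additive, and its shape) is an assumption here.
padding_mask = np.array([[True, True, True, False]])  # (batch, seq_len), True = attend
additive = ops.where(padding_mask, 0.0, -1e9)         # 0.0 keeps, -1e9 masks out
attention_mask = ops.reshape(additive, (1, 1, 1, 4))  # broadcast over heads and queries
print(ops.shape(attention_mask))  # (1, 1, 1, 4)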
@@ -277,7 +274,6 @@ def build(self, input_shape):
             self.intermediate_size,
         )
         self.down_proj.build(down_proj_input_shape)
-        self.training = False
         super().build(input_shape)
 
     def call(self, x):
@@ -322,7 +318,7 @@ class SmolLM3DecoderLayer(layers.Layer):
         layer_idx: Index of the current layer.
         intermediate_size: The intermediate size of the MLP.
         mlp_bias: Whether to use bias in MLP dense layers.
-        rms_norm_epsilon: Epsilon for RMSNormalization.
+        layer_norm_epsilon: Epsilon for RMSNormalization.
     """
 
     def __init__(
@@ -337,7 +333,7 @@ def __init__(
         layer_idx: int,
         intermediate_size: int,
         mlp_bias: bool,
-        rms_norm_epsilon: float,
+        layer_norm_epsilon: float,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -364,10 +360,10 @@ def __init__(
         )
 
         self.input_layernorm = layers.RMSNormalization(
-            epsilon=rms_norm_epsilon, axis=-1, name="input_layernorm"
+            epsilon=layer_norm_epsilon, axis=-1, name="input_layernorm"
         )
         self.post_attention_layernorm = layers.RMSNormalization(
-            epsilon=rms_norm_epsilon, axis=-1, name="post_attention_layernorm"
+            epsilon=layer_norm_epsilon, axis=-1, name="post_attention_layernorm"
         )
 
         self.attention_type = layer_types[layer_idx]
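
Note: the renamed layer_norm_epsilon argument is passed straight through to the epsilon of the RMSNormalization sublayers shown above. A minimal standalone sketch of that call, assuming a Keras 3 version that ships keras.layers.RMSNormalization (the epsilon value here is illustrative):

import numpy as np
from keras import layers

layer_norm_epsilon = 1e-6  # illustrative; the real value comes from the backbone config
input_layernorm = layers.RMSNormalization(
    epsilon=layer_norm_epsilon, axis=-1, name="input_layernorm"
)
x = np.random.rand(2, 4, 8).astype("float32")
print(input_layernorm(x).shape)  # (2, 4, 8)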
@@ -399,7 +395,6 @@ def build(self, input_shape):
         self.mlp.build(input_shape)
         self.input_layernorm.build(input_shape)
         self.post_attention_layernorm.build(input_shape)
-        self.training = False
 
         super().build(input_shape)
 
@@ -518,7 +513,6 @@ def build(self, input_shape):
            - position_ids_shape: (batch_size, seq_len)
        """
        # No internal layers to explicitly build here, as inv_freq is added in __init__
-        self.training = False
        super().build(input_shape)
 
    def call(

keras_hub/src/models/smollm3/smollm3_utils.py

Lines changed: 0 additions & 1 deletion
@@ -38,7 +38,6 @@ def eager_attention_forward(
     dropout=0.0,
     training=False,
 ):
-    print('training', training)
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
 
