
Commit 81eff73

Move causal mask computation to forward call

1 parent 6a53a7d · commit 81eff73

2 files changed (+17, -13 lines)


keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 6 additions & 12 deletions

@@ -1,9 +1,6 @@
 import keras

 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.transformer_layer_utils import (
-    compute_causal_mask,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding
@@ -78,7 +75,7 @@ def __init__(
             output_dim=hidden_dim,
             name="token_embedding",
         )
-        self.transformer_layers = []
+        self.decoder_layers = []

         for i in range(num_layers):
             layer = SmolLM3DecoderLayer(
@@ -94,7 +91,7 @@ def __init__(
                 mlp_bias=mlp_bias,
                 rms_norm_epsilon=rms_norm_epsilon,
             )
-            self.transformer_layers.append(layer)
+            self.decoder_layers.append(layer)

         self.norm = keras.layers.RMSNormalization(
             epsilon=layer_norm_epsilon,
@@ -117,22 +114,19 @@ def __init__(
             shape=(None,), dtype="int32", name="position_ids"
         )

+        print("token id", token_id_input.shape)
         hidden_states = self.token_embedding(token_id_input)
+        print("hidden states id", hidden_states.shape)
         position_embeddings = self.rotary_embedding(hidden_states, position_ids)

-        for decoder_layer in self.layers[:num_hidden_layers]:
+        for decoder_layer in self.decoder_layers[:num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=compute_causal_mask(
-                    hidden_states.shape[0],
-                    hidden_states.shape[1],
-                    hidden_states.shape[1],
-                ),
                 position_embeddings=position_embeddings,
                 **kwargs,
             )

-        sequence_output = self.layer_norm(x)
+        sequence_output = self.layer_norm(hidden_states)
         super().__init__(
             inputs={
                 "token_ids": token_id_input,

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 11 additions & 1 deletion

@@ -3,6 +3,9 @@
 from keras import layers
 from keras import ops

+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
 from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
 from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
 from keras_hub.src.models.smollm3.smollm3_utils import rope_init
@@ -216,14 +219,21 @@ def build(self, input_shape):
     def call(
         self,
         hidden_states,
-        attention_mask=None,
         position_embeddings=None,
         training=False,
         **kwargs,
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

+        attention_mask = (
+            compute_causal_mask(
+                ops.shape(hidden_states)[0],
+                ops.shape(hidden_states)[1],
+                ops.shape(hidden_states)[1],
+            ),
+        )
+
         # Self Attention
         attn_output, _ = self.self_attn(
             hidden_states=hidden_states,
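On the layer side, the mask is now rebuilt inside call() from ops.shape, which yields concrete dimensions at run time. Note that, as added here, the assignment wraps the compute_causal_mask call in an extra pair of parentheses with a trailing comma, so attention_mask ends up as a one-element tuple rather than a tensor. A minimal sketch of the unwrapped call, assuming the helper keeps the (batch_size, input_length, mask_length) positional signature the backbone used before this change:

    from keras import ops
    from keras_hub.src.layers.modeling.transformer_layer_utils import (
        compute_causal_mask,
    )

    # Dummy (batch, sequence, hidden) activations standing in for hidden_states.
    hidden_states = ops.ones((2, 5, 8))

    attention_mask = compute_causal_mask(
        ops.shape(hidden_states)[0],  # batch size
        ops.shape(hidden_states)[1],  # query length
        ops.shape(hidden_states)[1],  # key/value length
    )
    print(ops.shape(attention_mask))  # expected (2, 5, 5): causal (lower-triangular) mask per batch entry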
