
Commit 6a53a7d

Fix calls within causal model
1 parent b9e458d commit 6a53a7d

File tree

keras_hub/src/models/smollm3/smollm3_backbone.py
keras_hub/src/models/smollm3/smollm3_layers.py

2 files changed: +18 -12 lines changed

keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 15 additions & 9 deletions
@@ -1,6 +1,9 @@
 import keras

 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding
@@ -66,6 +69,7 @@ def __init__(
         max_position_embeddings,
         rope_theta,
         partial_rotary_factor,
+        num_hidden_layers,
         **kwargs,
     ):
         # === Layers ===
@@ -109,16 +113,21 @@ def __init__(
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )
-        padding_mask_input = keras.Input(
-            shape=(None,), dtype="int32", name="padding_mask"
+        position_ids = keras.Input(
+            shape=(None,), dtype="int32", name="position_ids"
         )
-        x = self.token_embedding(token_id_input)
-        position_embeddings = self.rotary_embedding(x)

-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        hidden_states = self.token_embedding(token_id_input)
+        position_embeddings = self.rotary_embedding(hidden_states, position_ids)
+
+        for decoder_layer in self.layers[:num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=#createcausalmask,
+                attention_mask=compute_causal_mask(
+                    hidden_states.shape[0],
+                    hidden_states.shape[1],
+                    hidden_states.shape[1],
+                ),
                 position_embeddings=position_embeddings,
                 **kwargs,
             )
@@ -127,7 +136,6 @@ def __init__(
         super().__init__(
             inputs={
                 "token_ids": token_id_input,
-                "padding_mask": padding_mask_input,
             },
             outputs=sequence_output,
             **kwargs,
@@ -137,7 +145,6 @@ def __init__(
         self.vocabulary_size = vocabulary_size
         self.num_layers = num_layers

-
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -150,4 +157,3 @@ def get_config(self):
             }
         )
         return config
-

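Note on the mask utility: compute_causal_mask from keras_hub's transformer_layer_utils returns a mask of shape (batch_size, output_length, input_length) whose entry (b, i, j) indicates whether position i may attend to position j (only j <= i is allowed), which is what the decoder layers above consume as attention_mask. A minimal standalone sketch of the call, using small illustrative sizes rather than the symbolic tensor dimensions in the functional model:

from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
)

# Illustrative sizes only; in the backbone these come from the runtime
# shape of `hidden_states`.
batch_size, seq_len = 2, 4

# Lower-triangular causal mask: position i may attend to positions 0..i
# and is blocked from attending to later (future) positions.
mask = compute_causal_mask(batch_size, seq_len, seq_len)
print(mask.shape)  # (2, 4, 4)
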
keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 3 additions & 3 deletions
@@ -169,7 +169,7 @@ def __init__(
         layer_idx: int,
         intermediate_size: int,
         mlp_bias: bool,
-        rms_norm_eps: float,
+        rms_norm_epsilon: float,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -196,10 +196,10 @@ def __init__(
         )

         self.input_layernorm = layers.RMSNormalization(
-            epsilon=rms_norm_eps, axis=-1, name="input_layernorm"
+            epsilon=rms_norm_epsilon, axis=-1, name="input_layernorm"
         )
         self.post_attention_layernorm = layers.RMSNormalization(
-            epsilon=rms_norm_eps, axis=-1, name="post_attention_layernorm"
+            epsilon=rms_norm_epsilon, axis=-1, name="post_attention_layernorm"
         )

         self.attention_type = layer_types[layer_idx]

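Note on the rename: this file only switches the constructor argument from rms_norm_eps to rms_norm_epsilon; the normalization itself is unchanged. A minimal standalone sketch of the configured layer, assuming `layers` here refers to keras.layers (Keras 3 ships an RMSNormalization layer) and using a placeholder epsilon value:

import keras
from keras import layers

# Placeholder for illustration; the real value is passed into
# SmolLM3DecoderLayer as `rms_norm_epsilon`.
rms_norm_epsilon = 1e-6

# Root-mean-square normalization over the last (feature) axis, matching
# the renamed argument in the diff above.
input_layernorm = layers.RMSNormalization(
    epsilon=rms_norm_epsilon, axis=-1, name="input_layernorm"
)

x = keras.random.normal((2, 8, 64))  # (batch, sequence, hidden) example
print(input_layernorm(x).shape)  # (2, 8, 64)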