 import keras
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding
@@ -66,6 +69,7 @@ def __init__(
         max_position_embeddings,
         rope_theta,
         partial_rotary_factor,
+        num_hidden_layers,
         **kwargs,
     ):
         # === Layers ===
@@ -109,16 +113,21 @@ def __init__(
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )
-        padding_mask_input = keras.Input(
-            shape=(None,), dtype="int32", name="padding_mask"
+        position_ids = keras.Input(
+            shape=(None,), dtype="int32", name="position_ids"
         )
-        x = self.token_embedding(token_id_input)
-        position_embeddings = self.rotary_embedding(x)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        hidden_states = self.token_embedding(token_id_input)
+        position_embeddings = self.rotary_embedding(hidden_states, position_ids)
+
+        for decoder_layer in self.layers[:num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=#createcausalmask,
+                attention_mask=compute_causal_mask(
+                    hidden_states.shape[0],
+                    hidden_states.shape[1],
+                    hidden_states.shape[1],
+                ),
                 position_embeddings=position_embeddings,
                 **kwargs,
             )
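For reference, a minimal sketch of the mask helper used above, assuming concrete integer shapes. keras_hub's compute_causal_mask(batch_size, input_length, output_length) returns a (batch_size, output_length, input_length) lower-triangular mask in which query position i may attend to key positions 0..i. Note that on a symbolic functional-API tensor, hidden_states.shape[0] is typically None, so the call in the diff likely needs concrete dimensions (or the mask computed inside the decoder layer's call) to run.

from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
)

# 1s at and below the diagonal: each query position attends only to
# itself and earlier key positions.
mask = compute_causal_mask(batch_size=2, input_length=4, output_length=4)
print(mask.shape)  # (2, 4, 4)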
@@ -127,7 +136,7 @@ def __init__(
         super().__init__(
             inputs={
                 "token_ids": token_id_input,
-                "padding_mask": padding_mask_input,
+                "position_ids": position_ids,
             },
             outputs=sequence_output,
             **kwargs,
@@ -137,7 +146,6 @@ def __init__(
         self.vocabulary_size = vocabulary_size
         self.num_layers = num_layers
 
-
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -150,4 +158,3 @@ def get_config(self):
             }
         )
         return config
-
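With this change the backbone's functional inputs are "token_ids" and "position_ids" rather than "token_ids" plus "padding_mask". A hypothetical usage sketch follows; the class name and module path are inferred from the file's imports, every constructor value is an illustrative guess, and the real constructor almost certainly takes further arguments not visible in this diff.

import numpy as np

from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone

backbone = SmolLM3Backbone(
    vocabulary_size=49152,         # illustrative value, not from the diff
    num_layers=4,                  # illustrative value
    max_position_embeddings=2048,  # illustrative value
    rope_theta=10000.0,            # illustrative value
    partial_rotary_factor=0.5,     # illustrative value
    num_hidden_layers=4,           # new argument introduced in this diff
)
token_ids = np.random.randint(0, 49152, size=(1, 8))
position_ids = np.arange(8, dtype="int32")[None, :]
outputs = backbone(
    {"token_ids": token_ids, "position_ids": position_ids}
)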