@@ -1,6 +1,9 @@
 import keras
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding
@@ -68,12 +71,12 @@ def __init__(
         **kwargs,
     ):
         # === Layers ===
-        self.token_embedding = keras.layers.Embedding(
+        self.token_embedding = ReversibleEmbedding(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             name="token_embedding",
         )
-        self.decoder_layers = []
+        self.transformer_layers = []
 
         for i in range(num_layers):
             layer = SmolLM3DecoderLayer(
@@ -87,10 +90,10 @@ def __init__(
                 layer_idx=i,
                 intermediate_size=intermediate_dim,
                 mlp_bias=mlp_bias,
-                rms_norm_epsilon=layer_norm_epsilon,
+                layer_norm_epsilon=layer_norm_epsilon,
                 name=f"transformer_layer_{i}",
             )
-            self.decoder_layers.append(layer)
+            self.transformer_layers.append(layer)
 
         self.norm = keras.layers.RMSNormalization(
             epsilon=layer_norm_epsilon,
@@ -112,16 +115,20 @@ def __init__(
         position_id_input = keras.Input(
             shape=(None,), dtype="int32", name="position_ids"
         )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
 
         hidden_states = self.token_embedding(token_id_input)
         position_embeddings = self.rotary_embedding(
             hidden_states, position_id_input
         )
 
-        for decoder_layer in self.decoder_layers[:num_layers]:
+        for decoder_layer in self.transformer_layers[:num_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
                 position_embeddings=position_embeddings,
+                decoder_padding_mask=padding_mask_input,
                 **kwargs,
             )
 
@@ -130,21 +137,48 @@ def __init__(
             inputs={
                 "token_ids": token_id_input,
                 "position_ids": position_id_input,
+                "padding_mask": padding_mask_input,
             },
             outputs=sequence_output,
             **kwargs,
         )
 
         # === Config ===
         self.vocabulary_size = vocabulary_size
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
         self.num_layers = num_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.rope_layer_enabled_list = rope_layer_enabled_list
+        self.layer_types = layer_types
+        self.mlp_bias = mlp_bias
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.partial_rotary_factor = partial_rotary_factor
 
     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "vocabulary_size": self.vocabulary_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
                 "num_layers": self.num_layers,
+                "num_attention_heads": self.num_attention_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "attention_bias": self.attention_bias,
+                "attention_dropout": self.attention_dropout,
+                "rope_layer_enabled_list": self.rope_layer_enabled_list,
+                "layer_types": self.layer_types,
+                "mlp_bias": self.mlp_bias,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "max_position_embeddings": self.max_position_embeddings,
+                "rope_theta": self.rope_theta,
+                "partial_rotary_factor": self.partial_rotary_factor,
             }
         )
         return config
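
Usage note: a minimal sketch of how the updated backbone could be exercised end to end. It assumes the class is exported as SmolLM3Backbone from keras_hub.src.models.smollm3.smollm3_backbone (the class name, module path, and valid layer_types values are not shown in this diff) and uses small, illustrative hyperparameters; the point is that the functional model now takes the new "padding_mask" input alongside "token_ids" and "position_ids".

import numpy as np

# Hypothetical class name and import path; neither appears in this diff.
from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone

# Small, illustrative hyperparameters (not a real SmolLM3 preset).
backbone = SmolLM3Backbone(
    vocabulary_size=32000,
    hidden_dim=256,
    intermediate_dim=1024,
    num_layers=2,
    num_attention_heads=8,
    num_key_value_heads=2,
    attention_bias=False,
    attention_dropout=0.0,
    rope_layer_enabled_list=[True, True],
    layer_types=["attention", "attention"],  # assumed values
    mlp_bias=False,
    layer_norm_epsilon=1e-6,
    max_position_embeddings=2048,
    rope_theta=10000.0,
    partial_rotary_factor=1.0,
)

# The functional graph now expects the padding_mask input added in this change.
batch_size, seq_len = 2, 16
outputs = backbone(
    {
        "token_ids": np.random.randint(0, 32000, (batch_size, seq_len)).astype("int32"),
        "position_ids": np.tile(np.arange(seq_len, dtype="int32"), (batch_size, 1)),
        "padding_mask": np.ones((batch_size, seq_len), dtype="int32"),
    }
)
print(outputs.shape)  # (batch_size, seq_len, hidden_dim)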