@@ -168,16 +168,15 @@ def __init__(
         layer_types: list[str],
         _attn_implementation: str,
         layer_idx: int,
-        intermediate_size: int,  # For MLP
-        mlp_bias: bool,  # For MLP
-        rms_norm_eps: float,  # For RMSNorm
+        intermediate_size: int,
+        mlp_bias: bool,
+        rms_norm_eps: float,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.hidden_size = hidden_size
-        self.layer_idx = layer_idx  # Store layer_idx
+        self.layer_idx = layer_idx
 
-        # Pass all necessary config parameters to SmolLM3AttentionKeras
         self.self_attn = SmolLM3Attention(
             hidden_size=hidden_size,
             num_attention_heads=num_attention_heads,
@@ -221,7 +220,7 @@ def call(
         hidden_states,
         attention_mask=None,
         position_embeddings=None,
-        training=False,  # Keras layers have a 'training' argument in call
+        training=False,
         **kwargs,
     ):
         residual = hidden_states
@@ -231,15 +230,12 @@ def call(
         attn_output, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            position_embeddings=position_embeddings,  # Pass position_embeddings
-            training=training,  # Pass training state
+            position_embeddings=position_embeddings,
+            training=training,
             **kwargs,
         )
-        hidden_states = ops.add(
-            residual, attn_output
-        )  # Add attention output to residual
+        hidden_states = ops.add(residual, attn_output)
 
-        # Fully Connected (MLP)
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
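For context, the lines touched above are the standard pre-norm decoder block: normalize, attend, add the residual, then normalize, run the MLP, and add a second residual (that final add sits just past the end of this hunk). The following is a minimal, runnable sketch of that wiring only. The class name DecoderBlockSketch and the sublayers are stand-ins chosen for illustration: keras.layers.MultiHeadAttention replaces SmolLM3Attention, LayerNormalization replaces the RMSNorm layers, and a small Dense stack replaces the SmolLM3 MLP; it is not the implementation in this PR.

# Sketch of the pre-norm residual structure used by the decoder layer above.
# SmolLM3Attention / the RMSNorm layers / the MLP are assumed to be defined
# elsewhere in the real file; simple Keras stand-ins are used here so the
# residual wiring can be run in isolation.
import keras
from keras import ops


class DecoderBlockSketch(keras.layers.Layer):
    def __init__(self, hidden_size, num_attention_heads, intermediate_size,
                 rms_norm_eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        # Stand-in sublayers (hypothetical; not the SmolLM3 classes).
        self.input_layernorm = keras.layers.LayerNormalization(epsilon=rms_norm_eps)
        self.self_attn = keras.layers.MultiHeadAttention(
            num_heads=num_attention_heads,
            key_dim=hidden_size // num_attention_heads,
        )
        self.post_attention_layernorm = keras.layers.LayerNormalization(epsilon=rms_norm_eps)
        self.mlp = keras.Sequential([
            keras.layers.Dense(intermediate_size, activation="silu"),
            keras.layers.Dense(hidden_size),
        ])

    def call(self, hidden_states, attention_mask=None, training=False):
        # Attention sub-block: pre-norm, attend, add residual.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states, hidden_states,
            attention_mask=attention_mask,
            training=training,
        )
        hidden_states = ops.add(residual, hidden_states)

        # MLP sub-block: pre-norm, project, add second residual.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return ops.add(residual, hidden_states)


# Usage: shapes are preserved through the block.
x = keras.random.normal((2, 16, 64))
block = DecoderBlockSketch(hidden_size=64, num_attention_heads=4, intermediate_size=256)
print(block(x).shape)  # (2, 16, 64)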