Commit 2448d80

remove unnecessary comments
1 parent 1369733 commit 2448d80

File tree

1 file changed: +8 -12 lines changed

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 8 additions & 12 deletions
@@ -168,16 +168,15 @@ def __init__(
         layer_types: list[str],
         _attn_implementation: str,
         layer_idx: int,
-        intermediate_size: int,  # For MLP
-        mlp_bias: bool,  # For MLP
-        rms_norm_eps: float,  # For RMSNorm
+        intermediate_size: int,
+        mlp_bias: bool,
+        rms_norm_eps: float,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.hidden_size = hidden_size
-        self.layer_idx = layer_idx  # Store layer_idx
+        self.layer_idx = layer_idx
 
-        # Pass all necessary config parameters to SmolLM3AttentionKeras
         self.self_attn = SmolLM3Attention(
             hidden_size=hidden_size,
             num_attention_heads=num_attention_heads,
@@ -221,7 +220,7 @@ def call(
         hidden_states,
         attention_mask=None,
         position_embeddings=None,
-        training=False,  # Keras layers have a 'training' argument in call
+        training=False,
         **kwargs,
     ):
         residual = hidden_states
@@ -231,15 +230,12 @@ def call(
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
-            position_embeddings=position_embeddings,  # Pass position_embeddings
-            training=training,  # Pass training state
+            position_embeddings=position_embeddings,
+            training=training,
            **kwargs,
        )
-        hidden_states = ops.add(
-            residual, attn_output
-        )  # Add attention output to residual
+        hidden_states = ops.add(residual, attn_output)
 
-        # Fully Connected (MLP)
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
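For context, the layer touched by this diff follows the usual pre-norm residual pattern: normalize, apply attention, add the residual, then normalize again, apply the MLP, and add a second residual. The sketch below only illustrates that wiring; it is not the SmolLM3 implementation. The class name `DecoderBlockSketch`, the `input_layernorm` attribute name, and the use of `keras.layers.MultiHeadAttention` and `LayerNormalization` as stand-ins for `SmolLM3Attention` and RMSNorm are all assumptions made for the example.

```python
import keras
from keras import ops


class DecoderBlockSketch(keras.layers.Layer):
    """Pre-norm decoder block mirroring the residual wiring shown in the diff.

    Stand-ins (hypothetical): MultiHeadAttention replaces SmolLM3Attention and
    LayerNormalization replaces RMSNorm, so the pattern can run on its own.
    """

    def __init__(self, hidden_size, num_heads, intermediate_size,
                 rms_norm_eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.input_layernorm = keras.layers.LayerNormalization(epsilon=rms_norm_eps)
        self.self_attn = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=hidden_size // num_heads
        )
        self.post_attention_layernorm = keras.layers.LayerNormalization(
            epsilon=rms_norm_eps
        )
        self.mlp = keras.Sequential(
            [
                keras.layers.Dense(intermediate_size, activation="silu"),
                keras.layers.Dense(hidden_size),
            ]
        )

    def call(self, hidden_states, training=False):
        # Attention sub-block: norm -> self-attention -> residual add.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(
            query=hidden_states,
            value=hidden_states,
            use_causal_mask=True,
            training=training,
        )
        hidden_states = ops.add(residual, attn_output)

        # MLP sub-block: norm -> MLP -> residual add.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states, training=training)
        return ops.add(residual, hidden_states)


# Quick check: a dummy batch of 2 sequences, length 4, hidden size 8.
block = DecoderBlockSketch(hidden_size=8, num_heads=2, intermediate_size=32)
x = keras.random.normal((2, 4, 8))
print(block(x).shape)  # (2, 4, 8)
```

The diff itself only removes comments and collapses the multi-line `ops.add(...)` call onto one line; the wiring above is unchanged by the commit.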
