
Commit 6819fd1

Build all layers
1 parent 6ab2e5c commit 6819fd1

1 file changed: +58 −1 lines changed

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 58 additions & 1 deletion
@@ -82,6 +82,23 @@ def __init__(
             else True
         )  # Default to True if index out of bounds
 
+    def build(self, input_shape):
+        """
+        Builds the internal Dense layers.
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [hidden_states_shape, position_embeddings_shape_tuple, attention_mask_shape]
+                - hidden_states_shape: (batch_size, seq_len, hidden_size)
+        """
+        # The input shape to the Dense layers (q_proj, k_proj, v_proj, o_proj)
+        # is the same as the hidden_states input to SmolLM3Attention.
+        hidden_states_shape = input_shape[0]
+        self.q_proj.build(hidden_states_shape)
+        self.k_proj.build(hidden_states_shape)
+        self.v_proj.build(hidden_states_shape)
+        self.o_proj.build(hidden_states_shape)
+        super().build(input_shape)
+
     def call(
         self,
         hidden_states,
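
For orientation, a minimal sketch of the build contract this hunk introduces: one shape per input, with the rotary cos/sin pair nested as a tuple. All concrete sizes below are assumed for illustration, not taken from the commit.

    # Sketch only: batch is None (dynamic), the other sizes are assumed values.
    batch, seq_len, hidden_size, head_dim = None, 128, 576, 64
    attn_build_input = [
        (batch, seq_len, hidden_size),  # hidden_states
        ((batch, seq_len, head_dim), (batch, seq_len, head_dim)),  # (cos, sin)
        (batch, 1, seq_len, seq_len),  # attention_mask
    ]
    # SmolLM3Attention.build(attn_build_input) then builds q/k/v/o eagerly
    # from the hidden_states shape alone (input_shape[0]).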
@@ -212,6 +229,25 @@ def __init__(
             self.hidden_size, use_bias=self.mlp_bias, name="down_proj"
         )
 
+    def build(self, input_shape):
+        """
+        Builds the internal Dense layers.
+        Args:
+            input_shape: The shape of the input to this layer
+                (batch_size, seq_len, hidden_size).
+        """
+        self.gate_proj.build(input_shape)
+        self.up_proj.build(input_shape)
+        # The down_proj takes intermediate_output, which has shape
+        # (batch_size, seq_len, intermediate_size)
+        down_proj_input_shape = (
+            input_shape[0],
+            input_shape[1],
+            self.intermediate_size,
+        )
+        self.down_proj.build(down_proj_input_shape)
+        super().build(input_shape)
+
     def call(self, x):
         """
         Forward pass for SmolLM3MLP.
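
A quick sketch of the shape flow this build method encodes (the sizes are assumed examples, not from the commit):

    # Sketch only: gate_proj/up_proj consume the layer input directly;
    # their output widens to intermediate_size, which is the shape
    # down_proj must be built against.
    hidden_size, intermediate_size = 576, 1536  # assumed values
    x_shape = (None, 128, hidden_size)
    down_proj_input_shape = (x_shape[0], x_shape[1], intermediate_size)
    assert down_proj_input_shape == (None, 128, 1536)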
@@ -321,9 +357,13 @@ def build(self, input_shape):
 
         attn_mask_shape = (batch_size, 1, seq_len, seq_len)
 
+        # Pass the correct input shape to self_attn's build method.
+        # The input_shape for self_attn.build is a list:
+        # [hidden_states_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
         self.self_attn.build(
             [input_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
         )
+
         self.mlp.build(input_shape)
         self.input_layernorm.build(input_shape)
         self.post_attention_layernorm.build(input_shape)
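
This hunk follows the usual Keras 3 convention of building sublayers eagerly inside build, so all weights exist before the first call. A minimal generic sketch of the pattern (not the actual SmolLM3 class):

    import keras

    class Block(keras.layers.Layer):
        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.dense = keras.layers.Dense(units)

        def build(self, input_shape):
            # Build the sublayer eagerly from the known input shape.
            self.dense.build(input_shape)
            # The base build has nothing to create; calling it keeps the
            # layer's build bookkeeping consistent.
            super().build(input_shape)

    blk = Block(16)
    blk.build((None, 8))
    assert blk.built and blk.dense.built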
@@ -430,7 +470,24 @@ def __init__(
         )
         self.original_inv_freq = self.inv_freq
 
-    def call(self, x, position_ids):
+    def build(self, input_shape):
+        """
+        Builds the layer. For SmolLM3RotaryEmbedding, this mainly ensures
+        that the parent layer's build is called.
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [x_shape, position_ids_shape]
+                - x_shape: (batch_size, ..., head_dim)
+                - position_ids_shape: (batch_size, seq_len)
+        """
+        # No internal layers to explicitly build here, as inv_freq is added in __init__
+        super().build(input_shape)
+
+    def call(
+        self,
+        x,
+        position_ids,
+    ):
         """
         Forward pass for SmolLM3RotaryEmbedding.
 
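Since inv_freq is created in __init__, the new build here is essentially a pass-through. A minimal generic sketch of that situation (not the actual class; head_dim and base are assumed values):

    import keras
    import numpy as np

    class RotarySketch(keras.layers.Layer):
        def __init__(self, head_dim=64, base=10000.0, **kwargs):
            super().__init__(**kwargs)
            # Constant inverse frequencies, analogous to inv_freq above.
            self.inv_freq = 1.0 / (
                base ** (np.arange(0, head_dim, 2, dtype="float32") / head_dim)
            )

        def build(self, input_shape):
            # Nothing to create; just defer to the base implementation.
            super().build(input_shape)

    rope = RotarySketch()
    rope.build([(None, 8, 128, 64), (None, 128)])  # [x_shape, position_ids_shape]
    assert rope.built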