@@ -82,6 +82,23 @@ def __init__(
            else True
        )  # Default to True if index out of bounds

+    def build(self, input_shape):
+        """
+        Builds the internal Dense layers.
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [hidden_states_shape, position_embeddings_shape_tuple, attention_mask_shape]
+                - hidden_states_shape: (batch_size, seq_len, hidden_size)
+        """
+        # The input shape to the Dense layers (q_proj, k_proj, v_proj, o_proj)
+        # is the same as the hidden_states input to SmolLM3Attention.
+        hidden_states_shape = input_shape[0]
+        self.q_proj.build(hidden_states_shape)
+        self.k_proj.build(hidden_states_shape)
+        self.v_proj.build(hidden_states_shape)
+        self.o_proj.build(hidden_states_shape)
+        super().build(input_shape)
+
    def call(
        self,
        hidden_states,
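For reference, a minimal self-contained sketch of the pattern this hunk introduces: a layer that owns several Dense projections and builds them from the first entry of a list of input shapes. The TinyAttention name, the hidden_size=64 / head_dim=16 values, and the symbolic shapes below are illustrative assumptions, not the repo's actual class or configuration.

import keras

class TinyAttention(keras.layers.Layer):
    """Illustrative sketch only; not the repo's SmolLM3Attention."""

    def __init__(self, hidden_size=64, **kwargs):
        super().__init__(**kwargs)
        self.q_proj = keras.layers.Dense(hidden_size, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(hidden_size, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(hidden_size, use_bias=False, name="v_proj")
        self.o_proj = keras.layers.Dense(hidden_size, use_bias=False, name="o_proj")

    def build(self, input_shape):
        # input_shape is [hidden_states_shape, (cos_shape, sin_shape), mask_shape];
        # only the hidden_states shape matters for the Dense sublayers.
        hidden_states_shape = input_shape[0]
        for proj in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
            proj.build(hidden_states_shape)
        super().build(input_shape)

layer = TinyAttention()
# Symbolic shapes: batch and sequence dims left as None, head_dim assumed to be 16.
layer.build([(None, None, 64), ((None, None, 16), (None, None, 16)), (None, 1, None, None)])
print(len(layer.weights))  # 4 kernels, one per projection

Building sublayers explicitly from shapes like this lets the enclosing model create all of its variables without tracing a forward pass.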
@@ -212,6 +229,25 @@ def __init__(
            self.hidden_size, use_bias=self.mlp_bias, name="down_proj"
        )

+    def build(self, input_shape):
+        """
+        Builds the internal Dense layers.
+        Args:
+            input_shape: The shape of the input to this layer
+                (batch_size, seq_len, hidden_size).
+        """
+        self.gate_proj.build(input_shape)
+        self.up_proj.build(input_shape)
+        # The down_proj takes intermediate_output, which has shape
+        # (batch_size, seq_len, intermediate_size)
+        down_proj_input_shape = (
+            input_shape[0],
+            input_shape[1],
+            self.intermediate_size,
+        )
+        self.down_proj.build(down_proj_input_shape)
+        super().build(input_shape)
+
    def call(self, x):
        """
        Forward pass for SmolLM3MLP.
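A similar self-contained sketch of the gated-MLP build pattern above; TinyMLP and the hidden_size=64 / intermediate_size=256 values are assumptions for illustration only. It shows why down_proj is built with intermediate_size as its last input dimension while gate_proj and up_proj see the raw hidden_size input.

import keras

class TinyMLP(keras.layers.Layer):
    """Illustrative gated MLP; not the repo's SmolLM3MLP."""

    def __init__(self, hidden_size=64, intermediate_size=256, **kwargs):
        super().__init__(**kwargs)
        self.intermediate_size = intermediate_size
        self.gate_proj = keras.layers.Dense(intermediate_size, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(intermediate_size, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(hidden_size, use_bias=False, name="down_proj")

    def build(self, input_shape):
        # input_shape: (batch_size, seq_len, hidden_size)
        self.gate_proj.build(input_shape)
        self.up_proj.build(input_shape)
        # down_proj consumes the gated activations of shape
        # (batch_size, seq_len, intermediate_size).
        self.down_proj.build((input_shape[0], input_shape[1], self.intermediate_size))
        super().build(input_shape)

    def call(self, x):
        return self.down_proj(keras.ops.silu(self.gate_proj(x)) * self.up_proj(x))

mlp = TinyMLP()
mlp.build((None, None, 64))
print([tuple(w.shape) for w in mlp.weights])  # kernels: (64, 256), (64, 256), (256, 64)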
@@ -321,9 +357,13 @@ def build(self, input_shape):

        attn_mask_shape = (batch_size, 1, seq_len, seq_len)

+        # Pass the correct input shape to self_attn's build method
+        # The input_shape for self_attn.build is a list:
+        # [hidden_states_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
        self.self_attn.build(
            [input_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
        )
+
        self.mlp.build(input_shape)
        self.input_layernorm.build(input_shape)
        self.post_attention_layernorm.build(input_shape)
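Spelled out with concrete placeholder numbers, the list handed to self_attn.build() has the structure below; hidden_size=64 and head_dim=16 are assumptions, and batch/sequence dims are left symbolic.

# Hypothetical shapes for illustration only.
hidden_states_shape = (None, None, 64)    # matches input_shape in the hunk above
pos_emb_shape = (None, None, 16)          # cos and sin tables share this shape
attn_mask_shape = (None, 1, None, None)   # (batch_size, 1, seq_len, seq_len)

self_attn_build_shapes = [
    hidden_states_shape,
    (pos_emb_shape, pos_emb_shape),
    attn_mask_shape,
]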
@@ -430,7 +470,24 @@ def __init__(
            )
            self.original_inv_freq = self.inv_freq

-    def call(self, x, position_ids):
+    def build(self, input_shape):
+        """
+        Builds the layer. For SmolLM3RotaryEmbedding, this mainly ensures
+        that the parent layer's build is called.
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [x_shape, position_ids_shape]
+                - x_shape: (batch_size, ..., head_dim)
+                - position_ids_shape: (batch_size, seq_len)
+        """
+        # No internal layers to explicitly build here, as inv_freq is added in __init__
+        super().build(input_shape)
+
+    def call(
+        self,
+        x,
+        position_ids,
+    ):
        """
        Forward pass for SmolLM3RotaryEmbedding.