class SmolLM3Attention(layers.Layer):
+    """
+    Multi-head attention layer for the SmolLM3 model.
+
+    Args:
+        hidden_size: The hidden size of the attention layer.
+        num_attention_heads: The number of attention heads.
+        num_key_value_heads: The number of key-value heads.
+        attention_bias: Whether to use bias in the attention projections.
+        attention_dropout: Dropout rate for the attention weights.
+        rope_layer_enabled_list: List indicating whether RoPE is enabled for each layer.
+        layer_types: List of layer types.
+        layer_idx: Index of the current layer.
+    """
+
    def __init__(
        self,
        hidden_size: int,
@@ -76,15 +90,25 @@ def call(
        training=False,
        **kwargs,
    ):
+        """
+        Forward pass for SmolLM3Attention.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size).
+            position_embeddings: Tuple of (cos, sin) tensors for RoPE.
+            attention_mask: Attention mask tensor.
+            training: Whether the layer is in training mode.
+        """
        self.training = training

        input_shape = ops.shape(hidden_states)[
            :-1
        ]  # Exclude last dim (hidden_size)

-        hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim)
-
-        query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape)
+        query_states = ops.reshape(
+            self.q_proj(hidden_states),
+            (*input_shape, self.num_attention_heads, self.head_dim),
+        )
        query_states = ops.transpose(
            query_states, axes=(0, 2, 1, 3)
        )  # (batch, num_heads, seq_len, head_dim)
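For reference (not part of the diff): the refactored reshape above only splits the projection's last dimension into attention heads before transposing to (batch, num_heads, seq_len, head_dim). A minimal NumPy sketch of that shape bookkeeping, using hypothetical sizes:

import numpy as np

# Hypothetical sizes, for illustration only.
batch_size, seq_len, hidden_size = 2, 8, 64
num_attention_heads = 4
head_dim = hidden_size // num_attention_heads  # 16

# Stand-in for self.q_proj(hidden_states): (batch, seq_len, hidden_size).
q = np.random.rand(batch_size, seq_len, hidden_size).astype("float32")

# Same reshape as in the hunk: split hidden_size into (num_heads, head_dim).
q = q.reshape(batch_size, seq_len, num_attention_heads, head_dim)

# Same transpose as ops.transpose(..., axes=(0, 2, 1, 3)).
q = q.transpose(0, 2, 1, 3)

print(q.shape)  # (2, 4, 8, 16) -> (batch, num_heads, seq_len, head_dim)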
@@ -129,8 +153,47 @@ def call(
        return attn_output, attn_weights

+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [hidden_states_shape, position_embeddings_shape_tuple, attention_mask_shape]
+                - hidden_states_shape: (batch_size, seq_len, hidden_size)
+                - position_embeddings_shape_tuple: (cos_shape, sin_shape), where each
+                  is (batch_size, seq_len, head_dim)
+                - attention_mask_shape: (batch_size, 1, seq_len, seq_len)
+
+        Returns:
+            A list of output shapes: [output_attn_output_shape, output_attn_weights_shape]
+        """
+        hidden_states_shape = input_shape[0]
+
+        batch_size = hidden_states_shape[0]
+        seq_len = hidden_states_shape[1]
+
+        output_attn_output_shape = (batch_size, seq_len, self.hidden_size)
+
+        output_attn_weights_shape = (
+            batch_size,
+            self.num_attention_heads,
+            seq_len,
+            seq_len,
+        )
+
+        return [output_attn_output_shape, output_attn_weights_shape]
+

class SmolLM3MLP(layers.Layer):
+    """
+    Multi-layer perceptron (MLP) block for the SmolLM3 model.
+
+    Args:
+        hidden_size: The hidden size of the MLP.
+        intermediate_size: The intermediate size of the MLP.
+        mlp_bias: Whether to use bias in the MLP dense layers.
+    """
+
    def __init__(
        self, hidden_size: int, intermediate_size: int, mlp_bias: bool, **kwargs
    ):
@@ -150,14 +213,50 @@ def __init__(
        )

    def call(self, x):
+        """
+        Forward pass for SmolLM3MLP.
+
+        Args:
+            x: Input tensor of shape (batch_size, seq_len, hidden_size).
+        """
        gate_output = activations.silu(self.gate_proj(x))
        up_output = self.up_proj(x)
        intermediate_output = gate_output * up_output
        down_proj_output = self.down_proj(intermediate_output)
        return down_proj_output

+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: The input shape (batch_size, seq_len, hidden_size).
+
+        Returns:
+            The output shape, which is the same as the input shape:
+            (batch_size, seq_len, hidden_size).
+        """
+        return input_shape
+

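For reference (not part of the diff): SmolLM3MLP.call computes down_proj(silu(gate_proj(x)) * up_proj(x)). A minimal NumPy sketch of that gated MLP with stand-in weight matrices and hypothetical sizes (bias omitted, as when mlp_bias=False):

import numpy as np

def silu(x):
    # SiLU / swish activation: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))

# Hypothetical sizes, for illustration only.
hidden_size, intermediate_size = 64, 256
x = np.random.rand(2, 8, hidden_size).astype("float32")

# Stand-ins for the gate_proj / up_proj / down_proj kernels.
w_gate = np.random.rand(hidden_size, intermediate_size).astype("float32")
w_up = np.random.rand(hidden_size, intermediate_size).astype("float32")
w_down = np.random.rand(intermediate_size, hidden_size).astype("float32")

out = (silu(x @ w_gate) * (x @ w_up)) @ w_down
print(out.shape)  # (2, 8, 64) -- same as the input, as compute_output_shape reports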
class SmolLM3DecoderLayer(layers.Layer):
+    """
+    Decoder layer for the SmolLM3 model, combining self-attention and an MLP.
+
+    Args:
+        hidden_size: The hidden size of the layer.
+        num_attention_heads: The number of attention heads.
+        num_key_value_heads: The number of key-value heads.
+        attention_bias: Whether to use bias in the attention projections.
+        attention_dropout: Dropout rate for the attention weights.
+        rope_layer_enabled_list: List indicating whether RoPE is enabled for each layer.
+        layer_types: List of layer types.
+        layer_idx: Index of the current layer.
+        intermediate_size: The intermediate size of the MLP.
+        mlp_bias: Whether to use bias in the MLP dense layers.
+        rms_norm_epsilon: Epsilon for RMSNormalization.
+    """
+
    def __init__(
        self,
        hidden_size: int,
@@ -206,8 +305,25 @@ def __init__(
        self.attention_type = layer_types[layer_idx]

    def build(self, input_shape):
-        # Build sub-layers
-        self.self_attn.build(input_shape)
+        """
+        Builds the sub-layers based on the input shape.
+
+        Args:
+            input_shape: The input shape to the decoder layer
+                (batch_size, seq_len, hidden_size).
+        """
+        # input_shape for SmolLM3DecoderLayer: (batch_size, seq_len, hidden_size)
+        batch_size = input_shape[0]
+        seq_len = input_shape[1]
+
+        head_dim = self.self_attn.head_dim
+        pos_emb_shape = (batch_size, seq_len, head_dim)
+
+        attn_mask_shape = (batch_size, 1, seq_len, seq_len)
+
+        self.self_attn.build(
+            [input_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
+        )
        self.mlp.build(input_shape)
        self.input_layernorm.build(input_shape)
        self.post_attention_layernorm.build(input_shape)
@@ -221,15 +337,21 @@ def call(
        training=False,
        **kwargs,
    ):
+        """
+        Forward pass for SmolLM3DecoderLayer.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size).
+            position_embeddings: Optional tuple of (cos, sin) tensors for RoPE.
+            training: Whether the layer is in training mode.
+        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

-        attention_mask = (
-            compute_causal_mask(
-                ops.shape(hidden_states)[0],
-                ops.shape(hidden_states)[1],
-                ops.shape(hidden_states)[1],
-            ),
+        attention_mask = compute_causal_mask(
+            ops.shape(hidden_states)[0],
+            ops.shape(hidden_states)[1],
+            ops.shape(hidden_states)[1],
        )

        # Self Attention
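For reference (not part of the diff): the hunk above removes a stray trailing comma that had wrapped the causal mask in a one-element tuple. A NumPy sketch of the bug, with causal_mask as a hypothetical stand-in for compute_causal_mask:

import numpy as np

def causal_mask(batch_size, q_len, kv_len):
    # Hypothetical stand-in: True where a query position may attend to a key position.
    mask = np.tril(np.ones((q_len, kv_len), dtype=bool))
    return np.broadcast_to(mask, (batch_size, q_len, kv_len))

# Old code: the trailing comma built a tuple instead of a mask tensor.
buggy = (causal_mask(2, 4, 4),)
print(type(buggy))           # <class 'tuple'>

# New code: the mask itself.
fixed = causal_mask(2, 4, 4)
print(fixed.shape)           # (2, 4, 4)
print(fixed[0].astype(int))  # lower-triangular pattern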
@@ -249,8 +371,32 @@ def call(
        return hidden_states

+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: The input shape (batch_size, seq_len, hidden_size).
+
+        Returns:
+            The output shape, which is the same as the input shape:
+            (batch_size, seq_len, hidden_size).
+        """
+        return input_shape
+

class SmolLM3RotaryEmbedding(layers.Layer):
+    """
+    Rotary Position Embedding (RoPE) layer for the SmolLM3 model.
+
+    Args:
+        hidden_size: The hidden size of the model.
+        num_attention_heads: The number of attention heads.
+        max_position_embeddings: The maximum sequence length for position embeddings.
+        rope_theta: The theta value for RoPE.
+        partial_rotary_factor: The factor for partial rotary embeddings.
+    """
+
    def __init__(
        self,
        hidden_size: int,
@@ -285,6 +431,14 @@ def __init__(
        self.original_inv_freq = self.inv_freq

    def call(self, x, position_ids):
+        """
+        Forward pass for SmolLM3RotaryEmbedding.
+
+        Args:
+            x: Input tensor, typically query or key states.
+                Shape can vary, but the last dimension is head_dim.
+            position_ids: Tensor of position IDs of shape (batch_size, seq_len).
+        """
        inv_freq_expanded = ops.expand_dims(
            ops.expand_dims(self.inv_freq, axis=0), axis=-1
        )
@@ -309,3 +463,31 @@ def call(self, x, position_ids):
        sin = ops.sin(emb) * self.attention_scaling

        return ops.cast(cos, x.dtype), ops.cast(sin, x.dtype)
+
+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [x_shape, position_ids_shape]
+                - x_shape: (batch_size, ..., head_dim)
+                - position_ids_shape: (batch_size, seq_len)
+
+        Returns:
+            A list of output shapes for (cos, sin):
+            [(batch_size, seq_len, head_dim), (batch_size, seq_len, head_dim)]
+        """
+        if input_shape[1] is not None and len(input_shape[1]) >= 2:
+            batch_size = input_shape[1][0]
+            seq_len = input_shape[1][1]
+        else:
+            # Fallback if position_ids_shape is None or malformed.
+            # In this case, the batch_size and seq_len are unknown.
+            batch_size = None
+            seq_len = None
+
+        # The output cos and sin have shape (batch_size, seq_len, head_dim).
+        output_shape = (batch_size, seq_len, self.head_dim)
+
+        return [output_shape, output_shape]
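For reference (not part of the diff): the cos/sin tables returned by call come from an outer product of position IDs with inverse frequencies derived from rope_theta. A minimal NumPy sketch of that computation with hypothetical values (the layer additionally multiplies by attention_scaling, omitted here):

import numpy as np

# Hypothetical values, for illustration only.
rope_theta = 10000.0
head_dim = 8
position_ids = np.arange(6, dtype="float32")[None, :]  # (batch=1, seq_len=6)

# Inverse frequencies over the even dimensions: rope_theta ** (-2i / head_dim).
inv_freq = 1.0 / (rope_theta ** (np.arange(0, head_dim, 2, dtype="float32") / head_dim))

# Angles: (batch, seq_len, head_dim // 2), duplicated to cover the full head_dim.
freqs = position_ids[..., None] * inv_freq[None, None, :]
emb = np.concatenate([freqs, freqs], axis=-1)

cos, sin = np.cos(emb), np.sin(emb)
print(cos.shape, sin.shape)  # (1, 6, 8) each -- matching compute_output_shape above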