Commit d5767c1

Fix causal mask call
1 parent b0080f2 commit d5767c1

4 files changed, +202 -27 lines changed


keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 7 additions & 5 deletions
@@ -89,6 +89,7 @@ def __init__(
                 intermediate_size=intermediate_dim,
                 mlp_bias=mlp_bias,
                 rms_norm_epsilon=layer_norm_epsilon,
+                name=f"transformer_layer_{i}",
             )
             self.decoder_layers.append(layer)

@@ -109,14 +110,14 @@ def __init__(
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )
-        position_ids = keras.Input(
+        position_id_input = keras.Input(
             shape=(None,), dtype="int32", name="position_ids"
         )

-        print("token id", token_id_input.shape)
         hidden_states = self.token_embedding(token_id_input)
-        print("hidden states id", hidden_states.shape)
-        position_embeddings = self.rotary_embedding(hidden_states, position_ids)
+        position_embeddings = self.rotary_embedding(
+            hidden_states, position_id_input
+        )

         for decoder_layer in self.decoder_layers[:num_hidden_layers]:
             hidden_states = decoder_layer(
@@ -125,10 +126,11 @@ def __init__(
                 **kwargs,
             )

-        sequence_output = self.layer_norm(hidden_states)
+        sequence_output = self.norm(hidden_states)
         super().__init__(
             inputs={
                 "token_ids": token_id_input,
+                "position_ids": position_id_input,
             },
             outputs=sequence_output,
             **kwargs,
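
With "position_ids" promoted to a real functional input, the backbone is now called with both tensors. A minimal usage sketch, assuming an already-constructed SmolLM3Backbone instance named backbone (its constructor is not part of this diff; the values below are illustrative only):

    import numpy as np

    # Hypothetical usage: backbone is an already-built SmolLM3Backbone.
    token_ids = np.array([[5, 128, 42, 7]], dtype="int32")                # (batch, seq_len)
    position_ids = np.arange(token_ids.shape[1], dtype="int32")[None, :]  # (1, seq_len)

    sequence_output = backbone(
        {"token_ids": token_ids, "position_ids": position_ids}
    )  # (batch, seq_len, hidden_size), taken after the final norm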

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 193 additions & 11 deletions
@@ -12,6 +12,20 @@
 
 
 class SmolLM3Attention(layers.Layer):
+    """
+    Multi-head attention layer for SmolLM3 model.
+
+    Args:
+        hidden_size: The hidden size of the attention layer.
+        num_attention_heads: The number of attention heads.
+        num_key_value_heads: The number of key-value heads.
+        attention_bias: Whether to use bias in attention projections.
+        attention_dropout: Dropout rate for attention weights.
+        rope_layer_enabled_list: List indicating if RoPE is enabled for each layer.
+        layer_types: List of layer types.
+        layer_idx: Index of the current layer.
+    """
+
     def __init__(
         self,
         hidden_size: int,
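
The docstring above distinguishes num_attention_heads from num_key_value_heads: several query heads share one key/value head (grouped-query attention). A standalone sketch of that sharing pattern with toy sizes (hypothetical values, not the model's actual config; the model's own repetition helper is not shown in this diff):

    from keras import ops

    batch, seq_len, head_dim = 2, 6, 8
    num_attention_heads, num_key_value_heads = 8, 2
    group_size = num_attention_heads // num_key_value_heads  # 4 query heads per K/V head

    k = ops.ones((batch, num_key_value_heads, seq_len, head_dim))
    k_expanded = ops.repeat(k, group_size, axis=1)  # (batch, num_attention_heads, seq_len, head_dim)
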
@@ -76,15 +90,25 @@ def call(
         training=False,
         **kwargs,
     ):
+        """
+        Forward pass for SmolLM3Attention.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size).
+            position_embeddings: Tuple of (cos, sin) tensors for RoPE.
+            attention_mask: Attention mask tensor.
+            training: Whether the layer is in training mode.
+        """
         self.training = training

         input_shape = ops.shape(hidden_states)[
             :-1
         ]  # Exclude last dim (hidden_size)

-        hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim)
-
-        query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape)
+        query_states = ops.reshape(
+            self.q_proj(hidden_states),
+            (*input_shape, self.num_attention_heads, self.head_dim),
+        )
         query_states = ops.transpose(
             query_states, axes=(0, 2, 1, 3)
         )  # (batch, num_heads, seq_len, head_dim)
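
The rewritten reshape simply inlines the target shape before the transpose. The same head-split pattern in isolation, with made-up dimensions (hypothetical sizes):

    from keras import ops

    batch, seq_len, hidden_size = 2, 6, 32
    num_heads, head_dim = 4, 8                  # hidden_size == num_heads * head_dim

    q = ops.ones((batch, seq_len, hidden_size))                # projected query states
    q = ops.reshape(q, (batch, seq_len, num_heads, head_dim))
    q = ops.transpose(q, axes=(0, 2, 1, 3))                    # (batch, num_heads, seq_len, head_dim)
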
@@ -129,8 +153,47 @@ def call(
 
         return attn_output, attn_weights
 
+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [hidden_states_shape, position_embeddings_shape_tuple, attention_mask_shape]
+                - hidden_states_shape: (batch_size, seq_len, hidden_size)
+                - position_embeddings_shape_tuple: (cos_shape, sin_shape) where cos_shape/sin_shape is (batch_size, seq_len, head_dim)
+                - attention_mask_shape: (batch_size, 1, seq_len, seq_len)
+
+        Returns:
+            A list of output shapes: [output_attn_output_shape, output_attn_weights_shape]
+        """
+        hidden_states_shape = input_shape[0]
+
+        batch_size = hidden_states_shape[0]
+        seq_len = hidden_states_shape[1]
+
+        output_attn_output_shape = (batch_size, seq_len, self.hidden_size)
+
+        output_attn_weights_shape = (
+            batch_size,
+            self.num_attention_heads,
+            seq_len,
+            seq_len,
+        )
+
+        return [output_attn_output_shape, output_attn_weights_shape]
+
 
 class SmolLM3MLP(layers.Layer):
+    """
+    Multi-layer perceptron (MLP) block for SmolLM3 model.
+
+    Args:
+        hidden_size: The hidden size of the MLP.
+        intermediate_size: The intermediate size of the MLP.
+        mlp_bias: Whether to use bias in MLP dense layers.
+    """
+
     def __init__(
         self, hidden_size: int, intermediate_size: int, mlp_bias: bool, **kwargs
     ):
@@ -150,14 +213,50 @@ def __init__(
         )
 
     def call(self, x):
+        """
+        Forward pass for SmolLM3MLP.
+
+        Args:
+            x: Input tensor of shape (batch_size, seq_len, hidden_size).
+        """
         gate_output = activations.silu(self.gate_proj(x))
         up_output = self.up_proj(x)
         intermediate_output = gate_output * up_output
         down_proj_output = self.down_proj(intermediate_output)
         return down_proj_output
 
+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: The input shape (batch_size, seq_len, hidden_size).
+
+        Returns:
+            The output shape, which is the same as the input shape:
+            (batch_size, seq_len, hidden_size).
+        """
+        return input_shape
+
 
 class SmolLM3DecoderLayer(layers.Layer):
+    """
+    Decoder layer for SmolLM3 model, combining self-attention and MLP.
+
+    Args:
+        hidden_size: The hidden size of the layer.
+        num_attention_heads: The number of attention heads.
+        num_key_value_heads: The number of key-value heads.
+        attention_bias: Whether to use bias in attention projections.
+        attention_dropout: Dropout rate for attention weights.
+        rope_layer_enabled_list: List indicating if RoPE is enabled for each layer.
+        layer_types: List of layer types.
+        layer_idx: Index of the current layer.
+        intermediate_size: The intermediate size of the MLP.
+        mlp_bias: Whether to use bias in MLP dense layers.
+        rms_norm_epsilon: Epsilon for RMSNormalization.
+    """
+
     def __init__(
         self,
         hidden_size: int,
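
The MLP call above is the usual gated block: silu(gate_proj(x)) * up_proj(x) fed through down_proj, with the output shape matching the input shape, which is exactly what the new compute_output_shape reports. A self-contained sketch with toy sizes (hypothetical values):

    from keras import activations, layers, ops

    hidden_size, intermediate_size = 16, 64    # hypothetical toy sizes
    gate_proj = layers.Dense(intermediate_size, use_bias=False)
    up_proj = layers.Dense(intermediate_size, use_bias=False)
    down_proj = layers.Dense(hidden_size, use_bias=False)

    x = ops.ones((2, 5, hidden_size))                              # (batch, seq_len, hidden)
    out = down_proj(activations.silu(gate_proj(x)) * up_proj(x))   # same shape as x
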
@@ -206,8 +305,25 @@ def __init__(
         self.attention_type = layer_types[layer_idx]
 
     def build(self, input_shape):
-        # Build sub-layers
-        self.self_attn.build(input_shape)
+        """
+        Builds the sub-layers based on the input shape.
+
+        Args:
+            input_shape: The input shape to the decoder layer
+                (batch_size, seq_len, hidden_size).
+        """
+        # input_shape for SmolLM3DecoderLayer: (batch_size, seq_len, hidden_size)
+        batch_size = input_shape[0]
+        seq_len = input_shape[1]
+
+        head_dim = self.self_attn.head_dim
+        pos_emb_shape = (batch_size, seq_len, head_dim)
+
+        attn_mask_shape = (batch_size, 1, seq_len, seq_len)
+
+        self.self_attn.build(
+            [input_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape]
+        )
         self.mlp.build(input_shape)
         self.input_layernorm.build(input_shape)
         self.post_attention_layernorm.build(input_shape)
@@ -221,15 +337,21 @@ def call(
         training=False,
         **kwargs,
     ):
+        """
+        Forward pass for SmolLM3DecoderLayer.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size).
+            position_embeddings: Optional tuple of (cos, sin) tensors for RoPE.
+            training: Whether the layer is in training mode.
+        """
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
-        attention_mask = (
-            compute_causal_mask(
-                ops.shape(hidden_states)[0],
-                ops.shape(hidden_states)[1],
-                ops.shape(hidden_states)[1],
-            ),
+        attention_mask = compute_causal_mask(
+            ops.shape(hidden_states)[0],
+            ops.shape(hidden_states)[1],
+            ops.shape(hidden_states)[1],
         )
 
         # Self Attention
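
This hunk is the fix the commit title refers to: the old code's trailing comma wrapped compute_causal_mask(...) in a one-element tuple, so downstream code received a tuple rather than a mask tensor. For reference, a minimal sketch of an additive causal mask built with keras.ops, assuming the additive convention that eager_attention_forward below relies on (0 where attention is allowed, a large negative value where it is masked); compute_causal_mask itself is not shown in this diff:

    from keras import ops

    batch_size, seq_len = 2, 4
    row = ops.expand_dims(ops.arange(seq_len), axis=1)         # (seq_len, 1)
    col = ops.expand_dims(ops.arange(seq_len), axis=0)         # (1, seq_len)
    allowed = ops.cast(ops.less_equal(col, row), "float32")    # lower-triangular (seq_len, seq_len)

    additive_mask = (1.0 - allowed) * -1e9                     # 0 = attend, -1e9 = masked
    additive_mask = ops.broadcast_to(
        ops.reshape(additive_mask, (1, 1, seq_len, seq_len)),
        (batch_size, 1, seq_len, seq_len),
    )
    # eager_attention_forward then does: attn_weights = ops.add(attn_weights, additive_mask)
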
@@ -249,8 +371,32 @@ def call(
 
         return hidden_states
 
+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: The input shape (batch_size, seq_len, hidden_size).
+
+        Returns:
+            The output shape, which is the same as the input shape:
+            (batch_size, seq_len, hidden_size).
+        """
+        return input_shape
+
 
 class SmolLM3RotaryEmbedding(layers.Layer):
+    """
+    Rotary Position Embedding (RoPE) layer for SmolLM3 model.
+
+    Args:
+        hidden_size: The hidden size of the model.
+        num_attention_heads: The number of attention heads.
+        max_position_embeddings: The maximum sequence length for position embeddings.
+        rope_theta: The theta value for RoPE.
+        partial_rotary_factor: The factor for partial rotary embedding.
+    """
+
     def __init__(
         self,
         hidden_size: int,
@@ -285,6 +431,14 @@ def __init__(
         self.original_inv_freq = self.inv_freq
 
     def call(self, x, position_ids):
+        """
+        Forward pass for SmolLM3RotaryEmbedding.
+
+        Args:
+            x: Input tensor, typically query or key states.
+                Shape can vary, but the last dimension is head_dim.
+            position_ids: Tensor of position IDs of shape (batch_size, seq_len).
+        """
         inv_freq_expanded = ops.expand_dims(
             ops.expand_dims(self.inv_freq, axis=0), axis=-1
         )
@@ -309,3 +463,31 @@ def call(self, x, position_ids):
         sin = ops.sin(emb) * self.attention_scaling
 
         return ops.cast(cos, x.dtype), ops.cast(sin, x.dtype)
+
+    def compute_output_shape(self, input_shape):
+        """
+        Computes the output shape of the layer.
+
+        Args:
+            input_shape: A list/tuple of shapes for the inputs:
+                [x_shape, position_ids_shape]
+                - x_shape: (batch_size, ..., head_dim)
+                - position_ids_shape: (batch_size, seq_len)
+
+        Returns:
+            A list of output shapes for (cos, sin):
+            [(batch_size, seq_len, head_dim), (batch_size, seq_len, head_dim)]
+        """
+        if input_shape[1] is not None and len(input_shape[1]) >= 2:
+            batch_size = input_shape[1][0]
+            seq_len = input_shape[1][1]
+        else:
+            # Fallback if position_ids_shape is None or malformed.
+            # In this case, the batch_size and seq_len are unknown.
+            batch_size = None
+            seq_len = None
+
+        # The output cos and sin have shape (batch_size, seq_len, head_dim)
+        output_shape = (batch_size, seq_len, self.head_dim)
+
+        return [output_shape, output_shape]
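
For reference, the cos/sin tables this layer computes follow the standard RoPE recipe: per-dimension inverse frequencies, an outer product with the positions, then duplicated halves. A standalone numpy sketch with toy values (hypothetical head_dim and theta; the layer's attention_scaling factor is omitted):

    import numpy as np

    head_dim, rope_theta = 8, 10000.0
    positions = np.arange(6)                                                  # (seq_len,)

    inv_freq = 1.0 / (rope_theta ** (np.arange(0, head_dim, 2) / head_dim))   # (head_dim // 2,)
    freqs = np.outer(positions, inv_freq)                                     # (seq_len, head_dim // 2)
    emb = np.concatenate([freqs, freqs], axis=-1)                             # (seq_len, head_dim)
    cos, sin = np.cos(emb), np.sin(emb)                                       # per-position rotation tables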

keras_hub/src/models/smollm3/smollm3_utils.py

Lines changed: 2 additions & 2 deletions
@@ -48,8 +48,8 @@ def eager_attention_forward(
 
     # Apply attention mask if provided
     if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]]
-        attn_weights = ops.add(attn_weights, causal_mask)
+        # causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]]
+        attn_weights = ops.add(attn_weights, attention_mask)
 
     attn_weights = ops.softmax(attn_weights, axis=-1)
     if not training:
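
With the slicing removed, eager_attention_forward now adds the mask it is given directly to the raw scores before the softmax. A self-contained sketch of that step (hypothetical tensors; the real function's K/V head handling and dropout are outside this hunk):

    from keras import ops

    batch, heads, seq_len, head_dim = 1, 2, 4, 8
    query = ops.ones((batch, heads, seq_len, head_dim))
    key = ops.ones((batch, heads, seq_len, head_dim))
    attention_mask = ops.zeros((batch, 1, seq_len, seq_len))    # additive mask, 0 = attend

    attn_weights = ops.matmul(query, ops.transpose(key, axes=(0, 1, 3, 2)))
    attn_weights = attn_weights * (1.0 / (head_dim**0.5))       # usual scaled dot-product
    attn_weights = ops.add(attn_weights, attention_mask)        # mask applied as-is, as above
    attn_weights = ops.softmax(attn_weights, axis=-1)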

keras_hub/src/utils/transformers/convert_smollm3.py

Lines changed: 0 additions & 9 deletions
@@ -41,15 +41,6 @@ def convert_weights(backbone, loader, transformers_config):
         keras_variable=backbone.get_layer("token_embedding").embeddings,
         hf_weight_key="model.embed_tokens.weight",
     )
-    if not backbone.tie_word_embeddings:
-        loader.port_weight(
-            keras_variable=backbone.get_layer(
-                "token_embedding"
-            ).reverse_embeddings,
-            hf_weight_key="lm_head.weight",
-            # rearrange_pattern="b a -> a b",
-            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
-        )
 
 def transpose_and_reshape(x, shape):
     return np.reshape(np.transpose(x), shape)
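
The transpose_and_reshape helper kept at the end of this hunk is a small conversion utility; a quick numpy illustration of what it does (hypothetical shapes):

    import numpy as np

    def transpose_and_reshape(x, shape):
        return np.reshape(np.transpose(x), shape)

    w = np.zeros((6, 4))                          # e.g. an (out_features, in_features) HF kernel
    kernel = transpose_and_reshape(w, (4, 2, 3))  # transpose to (4, 6), then reshape
    print(kernel.shape)                           # (4, 2, 3)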
