
Commit b9e458d

add most of smollm3backbone
1 parent 598fd74 commit b9e458d

2 files changed: +159 -10 lines changed
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding


@keras_hub_export(
    [
        "keras_hub.models.SmolLM3Backbone",
        "keras_hub.models.SmolLMBackbone",
    ]
)
class SmolLM3Backbone(Backbone):
    """The SmolLM3 Transformer core architecture with hyperparameters.

    This network implements a Transformer-based decoder network, SmolLM3,
    as described in the SmolLM3 model architecture. It includes the
    embedding lookups and transformer layers.

    The default constructor gives a fully customizable, randomly initialized
    SmolLM3 model with any number of layers, heads, and embedding dimensions.
    To load preset architectures and weights, use the `from_preset`
    constructor.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        hidden_dim: int. The size of the transformer hidden state.
        intermediate_dim: int. The output dimension of the first Dense layer
            in each decoder MLP block.
        num_layers: int. The number of transformer decoder layers.
        num_attention_heads: int. The number of attention heads per layer.
        num_key_value_heads: int. The number of key/value heads per layer.
        attention_bias: bool. Whether to use a bias in the attention
            projections.
        attention_dropout: float. Dropout probability for attention scores.
        rope_layer_enabled_list: list of bool. Per-layer flags controlling
            whether rotary position embeddings are applied.
        layer_types: list of str. The attention layer type for each layer.
        mlp_bias: bool. Whether to use a bias in the MLP projections.
        rms_norm_epsilon: float. Epsilon for the per-layer RMS normalization.
        layer_norm_epsilon: float. Epsilon for the final output normalization.
        max_position_embeddings: int. The maximum sequence length supported
            by the rotary embedding.
        rope_theta: float. The base frequency of the rotary embedding.
        partial_rotary_factor: float. The fraction of the head dimension that
            receives rotary embeddings.

    Examples:

    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Pretrained SmolLM3 decoder.
    model = keras_hub.models.SmolLM3Backbone.from_preset("...")
    model(input_data)

    # Randomly initialized SmolLM3 decoder with custom config.
    model = keras_hub.models.SmolLM3Backbone(
        ...
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        hidden_dim,
        intermediate_dim,
        num_layers,
        num_attention_heads,
        num_key_value_heads,
        attention_bias,
        attention_dropout,
        rope_layer_enabled_list,
        layer_types,
        mlp_bias,
        rms_norm_epsilon,
        layer_norm_epsilon,
        max_position_embeddings,
        rope_theta,
        partial_rotary_factor,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = keras.layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            name="token_embedding",
        )
        self.transformer_layers = []

        for i in range(num_layers):
            layer = SmolLM3DecoderLayer(
                hidden_size=hidden_dim,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                rope_layer_enabled_list=rope_layer_enabled_list,
                layer_types=layer_types,
                layer_idx=i,
                intermediate_size=intermediate_dim,
                mlp_bias=mlp_bias,
                rms_norm_epsilon=rms_norm_epsilon,
            )
            self.transformer_layers.append(layer)

        self.norm = keras.layers.RMSNormalization(
            epsilon=layer_norm_epsilon,
            name="sequence_output_layernorm",
        )

        self.rotary_embedding = SmolLM3RotaryEmbedding(
            hidden_size=hidden_dim,
            num_attention_heads=num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            partial_rotary_factor=partial_rotary_factor,
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        position_embeddings = self.rotary_embedding(x)

        for decoder_layer in self.transformer_layers:
            x = decoder_layer(
                x,
                # TODO: build the causal attention mask here; the padding
                # mask is passed through as a stand-in for now.
                attention_mask=padding_mask_input,
                position_embeddings=position_embeddings,
            )

        sequence_output = self.norm(x)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_attention_heads = num_attention_heads

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_attention_heads": self.num_attention_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
            }
        )
        return config
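For context, a minimal smoke test of the new backbone could look like the sketch below. The hyperparameter values are made-up assumptions for illustration (not a released preset), and the sketch presumes `SmolLM3DecoderLayer` and `SmolLM3RotaryEmbedding` accept the keyword arguments the backbone passes above.

```python
import numpy as np

import keras_hub

# Tiny, made-up configuration purely for shape-checking the graph.
# All values below are illustrative assumptions, not SmolLM3 defaults.
backbone = keras_hub.models.SmolLM3Backbone(
    vocabulary_size=1000,
    hidden_dim=64,
    intermediate_dim=128,
    num_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    attention_bias=False,
    attention_dropout=0.0,
    rope_layer_enabled_list=[True, True],
    layer_types=["attention", "attention"],  # placeholder layer type names
    mlp_bias=False,
    rms_norm_epsilon=1e-6,
    layer_norm_epsilon=1e-6,
    max_position_embeddings=128,
    rope_theta=10000.0,
    partial_rotary_factor=1.0,
)

input_data = {
    "token_ids": np.ones(shape=(1, 12), dtype="int32"),
    "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}
# Expected output shape: (batch, sequence_length, hidden_dim) = (1, 12, 64).
print(backbone(input_data).shape)
```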

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 6 additions & 10 deletions
@@ -16,9 +16,8 @@ def __init__(
         num_key_value_heads: int,
         attention_bias: bool,
         attention_dropout: float,
-        no_rope_layers: list[bool],
+        rope_layer_enabled_list: list[bool],
         layer_types: list[str],
-        _attn_implementation: str,
         layer_idx: int,
         **kwargs,
     ):
@@ -29,9 +28,8 @@ def __init__(
         self.num_key_value_heads = num_key_value_heads
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        self.no_rope_layers = no_rope_layers
+        self.rope_layer_enabled_list = rope_layer_enabled_list
         self.layer_types = layer_types
-        self._attn_implementation = _attn_implementation

         self.layer_idx = layer_idx

@@ -62,8 +60,8 @@ def __init__(
         )

         self.use_rope = (
-            self.no_rope_layers[self.layer_idx]
-            if self.layer_idx < len(self.no_rope_layers)
+            self.rope_layer_enabled_list[self.layer_idx]
+            if self.layer_idx < len(self.rope_layer_enabled_list)
             else True
         )  # Default to True if index out of bounds

@@ -166,9 +164,8 @@ def __init__(
         num_key_value_heads: int,
         attention_bias: bool,
         attention_dropout: float,
-        no_rope_layers: list[bool],
+        rope_layer_enabled_list: list[bool],
         layer_types: list[str],
-        _attn_implementation: str,
         layer_idx: int,
         intermediate_size: int,
         mlp_bias: bool,
@@ -185,9 +182,8 @@ def __init__(
             num_key_value_heads=num_key_value_heads,
             attention_bias=attention_bias,
             attention_dropout=attention_dropout,
-            no_rope_layers=no_rope_layers,
+            rope_layer_enabled_list=rope_layer_enabled_list,
             layer_types=layer_types,
-            _attn_implementation=_attn_implementation,
             layer_idx=layer_idx,
             name="self_attn",
         )
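The rename from `no_rope_layers` to `rope_layer_enabled_list` also clarifies the semantics: the list holds per-layer booleans where `True` means the layer applies rotary embeddings, and any layer index past the end of the list defaults to RoPE enabled. A standalone sketch of that lookup (not the actual layer class, just the selection logic from the hunk above):

```python
# Standalone sketch of the per-layer RoPE flag lookup shown in the diff above.
def use_rope_for_layer(
    rope_layer_enabled_list: list[bool], layer_idx: int
) -> bool:
    if layer_idx < len(rope_layer_enabled_list):
        return rope_layer_enabled_list[layer_idx]
    return True  # Default to True if index out of bounds.


assert use_rope_for_layer([True, False, True], 1) is False
assert use_rope_for_layer([True, False], 5) is True  # out of range -> enabled
```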
