
Commit b0080f2

Add convert_smollm3.py and update preset loader
1 parent 81eff73 commit b0080f2

5 files changed (+167, -11 lines)


keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 1 addition & 5 deletions
@@ -61,7 +61,6 @@ def __init__(
         rope_layer_enabled_list,
         layer_types,
         mlp_bias,
-        rms_norm_epsilon,
         layer_norm_epsilon,
         max_position_embeddings,
         rope_theta,
@@ -89,7 +88,7 @@ def __init__(
                 layer_idx=i,
                 intermediate_size=intermediate_dim,
                 mlp_bias=mlp_bias,
-                rms_norm_epsilon=rms_norm_epsilon,
+                rms_norm_epsilon=layer_norm_epsilon,
             )
             self.decoder_layers.append(layer)

@@ -145,9 +144,6 @@ def get_config(self):
             {
                 "vocabulary_size": self.vocabulary_size,
                 "num_layers": self.num_layers,
-                "num_query_heads": self.num_query_heads,
-                "hidden_dim": self.hidden_dim,
-                "intermediate_dim": self.intermediate_dim,
             }
         )
         return config

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 1 addition & 3 deletions
@@ -68,8 +68,6 @@ def __init__(
             else True
         )  # Default to True if index out of bounds

-        self._attention_interface = eager_attention_forward
-
     def call(
         self,
         hidden_states,
@@ -113,7 +111,7 @@ def call(
             query_states, key_states, cos, sin
         )

-        attn_output, attn_weights = self._attention_interface(
+        attn_output, attn_weights = eager_attention_forward(
             module=self,
             query=query_states,
             key=key_states,

keras_hub/src/models/smollm3/smollm3_utils.py

Lines changed: 5 additions & 3 deletions
@@ -34,8 +34,9 @@ def eager_attention_forward(
     key,
     value,
     attention_mask,
-    scaling: float,
-    dropout: float = 0.0,
+    scaling,
+    dropout=0.0,
+    training=False,
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -51,7 +52,8 @@ def eager_attention_forward(
         attn_weights = ops.add(attn_weights, causal_mask)

     attn_weights = ops.softmax(attn_weights, axis=-1)
-    attn_weights = random.dropout(attn_weights, rate=dropout)
+    if not training:
+        attn_weights = random.dropout(attn_weights, rate=dropout)
     attn_output = ops.matmul(attn_weights, value_states)
     attn_output = ops.transpose(attn_output, axes=(0, 2, 1, 3))
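
Both hunks above rely on repeat_kv(key, module.num_key_value_groups) to expand the grouped key/value heads up to the full query head count before the attention matmuls. The helper itself is not part of this diff; as a point of reference, here is a minimal sketch of what such a grouped-query-attention repeat typically looks like with keras.ops. The name and shapes below are assumptions for illustration, not the file's actual implementation.

from keras import ops

def repeat_kv_sketch(hidden_states, n_rep):
    # hidden_states: (batch, num_key_value_heads, seq_len, head_dim)
    if n_rep == 1:
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = ops.shape(hidden_states)
    # Insert a repeat axis next to the head axis, tile it, then fold it back,
    # yielding (batch, num_key_value_heads * n_rep, seq_len, head_dim).
    hidden_states = ops.expand_dims(hidden_states, axis=2)
    hidden_states = ops.tile(hidden_states, (1, 1, n_rep, 1, 1))
    return ops.reshape(
        hidden_states, (batch, num_kv_heads * n_rep, seq_len, head_dim)
    )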

keras_hub/src/utils/transformers/convert_smollm3.py

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
import numpy as np

from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone
from keras_hub.src.utils.preset_utils import load_json

backbone_cls = SmolLM3Backbone


def convert_backbone_config(transformers_config):
    return {
        "vocabulary_size": transformers_config["vocab_size"],
        "hidden_dim": transformers_config["hidden_size"],
        "num_layers": transformers_config["num_hidden_layers"],
        "num_attention_heads": transformers_config["num_attention_heads"],
        "num_key_value_heads": transformers_config["num_key_value_heads"],
        "intermediate_dim": transformers_config["intermediate_size"],
        "layer_norm_epsilon": transformers_config[
            "rms_norm_eps"
        ],  # Using rms_norm_eps as layer_norm_epsilon
        "max_position_embeddings": transformers_config[
            "max_position_embeddings"
        ],
        "rope_theta": transformers_config["rope_theta"],
        # partial_rotary_factor is not explicitly in config.json but is
        # inherited from the default value in the
        # `_compute_default_rope_parameters()` function.
        "partial_rotary_factor": 1.0,
        "attention_bias": transformers_config["attention_bias"],
        "attention_dropout": transformers_config["attention_dropout"],
        "rope_layer_enabled_list": transformers_config["no_rope_layers"],
        "layer_types": transformers_config["layer_types"],
        "mlp_bias": transformers_config["mlp_bias"],
        "num_hidden_layers": transformers_config[
            "num_hidden_layers"
        ],  # Redundant with num_layers, but kept for completeness
    }


def convert_weights(backbone, loader, transformers_config):
    loader.port_weight(
        keras_variable=backbone.get_layer("token_embedding").embeddings,
        hf_weight_key="model.embed_tokens.weight",
    )
    if not backbone.tie_word_embeddings:
        loader.port_weight(
            keras_variable=backbone.get_layer(
                "token_embedding"
            ).reverse_embeddings,
            hf_weight_key="lm_head.weight",
            # rearrange_pattern="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

    def transpose_and_reshape(x, shape):
        return np.reshape(np.transpose(x), shape)

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

        # Input layernorm
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layernorm.scale,
            hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
        )

        # Attention layers

        ## Query
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer._query_dense_layer_norm.scale,
            hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight",
        )
        ## Key
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer._key_dense_layer_norm.scale,
            hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight",
        )
        ## Value
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer._value_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        ## Output
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer._output_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
            # rearrange_patterns="c (a b) -> a b c",
            # rearrange_dims={"a": backbone.num_query_heads},
            hook_fn=transpose_and_reshape,
        )

        # MLP layers
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
            hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight",
            # rearrange_patterns="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_output_dense.kernel,
            hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight",
            # rearrange_patterns="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_gate_dense.kernel,
            hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight",
            # rearrange_patterns="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

        # Feedforward layernorm
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_layernorm.scale,
            hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
        )

    # Final normalization layer
    loader.port_weight(
        keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
        hf_weight_key="model.norm.weight",
    )

    return backbone


def convert_tokenizer(cls, preset, **kwargs):
    tokenizer_config = load_json(preset, "tokenizer.json")
    vocab = tokenizer_config["model"]["vocab"]
    merges = tokenizer_config["model"]["merges"]
    merges = [" ".join(item) for item in merges]

    # Load all special tokens with the exception of "reserved" ones.
    special_tokens = set()
    for token in tokenizer_config["added_tokens"]:
        if not token["content"].startswith("<|reserved_special_token_"):
            vocab[token["content"]] = token["id"]
            special_tokens.add(token["content"])

    kwargs.update(
        {
            "unsplittable_tokens": list(special_tokens),
        }
    )

    return cls(vocabulary=vocab, merges=merges, **kwargs)
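
convert_backbone_config is a pure mapping from keys in the checkpoint's config.json to SmolLM3Backbone keyword arguments, so it can be exercised on its own. A rough sketch with made-up, illustrative values (not real SmolLM3 hyperparameters), assuming a keras_hub build that includes this module:

from keras_hub.src.utils.transformers import convert_smollm3

# Illustrative values only; a real config.json ships with the HF checkpoint.
hf_config = {
    "vocab_size": 32000,
    "hidden_size": 256,
    "num_hidden_layers": 4,
    "num_attention_heads": 8,
    "num_key_value_heads": 2,
    "intermediate_size": 512,
    "rms_norm_eps": 1e-6,
    "max_position_embeddings": 2048,
    "rope_theta": 10000.0,
    "attention_bias": False,
    "attention_dropout": 0.0,
    "no_rope_layers": [1, 1, 1, 0],  # passed straight through as rope_layer_enabled_list
    "layer_types": ["full_attention"] * 4,
    "mlp_bias": False,
}

backbone_kwargs = convert_smollm3.convert_backbone_config(hf_config)
# e.g. backbone_kwargs["layer_norm_epsilon"] == 1e-6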

keras_hub/src/utils/transformers/preset_loader.py

Lines changed: 3 additions & 0 deletions
@@ -17,6 +17,7 @@
 from keras_hub.src.utils.transformers import convert_qwen
 from keras_hub.src.utils.transformers import convert_qwen3
 from keras_hub.src.utils.transformers import convert_qwen_moe
+from keras_hub.src.utils.transformers import convert_smollm3
 from keras_hub.src.utils.transformers import convert_vit
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader

@@ -56,6 +57,8 @@ def __init__(self, preset, config):
             self.converter = convert_qwen_moe
         elif model_type == "qwen3":
             self.converter = convert_qwen3
+        elif model_type == "smollm3":
+            self.converter = convert_smollm3
         else:
             raise ValueError(
                 "KerasHub has no converter for huggingface/transformers models "
