
Commit 186eaf8

Fix conversion weight names
1 parent d5767c1 commit 186eaf8

2 files changed: +11, -15 lines

keras_hub/src/models/smollm3/smollm3_utils.py

Lines changed: 0 additions & 1 deletion
@@ -48,7 +48,6 @@ def eager_attention_forward(
 
     # Apply attention mask if provided
     if attention_mask is not None:
-        # causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]]
         attn_weights = ops.add(attn_weights, attention_mask)
 
     attn_weights = ops.softmax(attn_weights, axis=-1)
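The surviving lines follow the standard additive-mask pattern: the mask carries large negative values at disallowed positions, so adding it to the raw attention scores before the softmax pushes those positions toward zero probability. A minimal sketch of that pattern with Keras ops follows; the tensor names and shapes are illustrative and not the actual eager_attention_forward signature.

```python
import numpy as np
from keras import ops

# Toy attention scores: (batch, heads, query_len, key_len).
attn_weights = ops.convert_to_tensor(
    np.random.rand(1, 2, 4, 4).astype("float32")
)

# Additive causal mask: 0.0 where attention is allowed, a large negative
# value above the diagonal so future positions are suppressed.
attention_mask = ops.convert_to_tensor(
    np.triu(np.full((1, 1, 4, 4), -1e9, dtype="float32"), k=1)
)

if attention_mask is not None:
    attn_weights = ops.add(attn_weights, attention_mask)

# Masked positions end up with (near-)zero probability after the softmax.
attn_weights = ops.softmax(attn_weights, axis=-1)
```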

keras_hub/src/utils/transformers/convert_smollm3.py

Lines changed: 11 additions & 14 deletions
@@ -30,9 +30,6 @@ def convert_backbone_config(transformers_config):
         "rope_layer_enabled_list": transformers_config["no_rope_layers"],
         "layer_types": transformers_config["layer_types"],
         "mlp_bias": transformers_config["mlp_bias"],
-        "num_hidden_layers": transformers_config[
-            "num_hidden_layers"
-        ],  # Redundant with num_layers, but kept for completeness
     }
 
 
@@ -50,41 +47,41 @@ def transpose_and_reshape(x, shape):
 
         # Input layernorm
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layernorm.scale,
+            keras_variable=decoder_layer.input_layernorm.scale,
             hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
         )
 
         # Attention layers
 
         ## Query
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
+            keras_variable=decoder_layer.self_attn.q_proj.kernel,
             hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
             hook_fn=transpose_and_reshape,
         )
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._query_dense_layer_norm.scale,
+            keras_variable=decoder_layer.self_attn.q_norm.scale,
             hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight",
         )
         ## Key
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
+            keras_variable=decoder_layer.self_attn.k_proj.kernel,
             hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
             hook_fn=transpose_and_reshape,
         )
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._key_dense_layer_norm.scale,
+            keras_variable=decoder_layer.self_attn.k_norm.scale,
             hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight",
         )
         ## Value
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._value_dense.kernel,
+            keras_variable=decoder_layer.self_attn.v_proj.kernel,
             hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
             hook_fn=transpose_and_reshape,
         )
         ## Output
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._output_dense.kernel,
+            keras_variable=decoder_layer.self_attn.o_proj.kernel,
             hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
             # rearrange_patterns="c (a b) -> a b c",
             # rearrange_dims={"a": backbone.num_query_heads},
@@ -93,27 +90,27 @@ def transpose_and_reshape(x, shape):
 
         # MLP layers
         loader.port_weight(
-            keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
+            keras_variable=decoder_layer.mlp.up_proj.kernel,
             hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight",
             # rearrange_patterns="b a -> a b",
             hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
         )
         loader.port_weight(
-            keras_variable=decoder_layer._feedforward_output_dense.kernel,
+            keras_variable=decoder_layer.mlp.down_proj.kernel,
             hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight",
             # rearrange_patterns="b a -> a b",
             hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
         )
         loader.port_weight(
-            keras_variable=decoder_layer._feedforward_gate_dense.kernel,
+            keras_variable=decoder_layer.mlp.gate_proj.kernel,
             hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight",
             # rearrange_patterns="b a -> a b",
             hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
         )
 
         # Feedforward layernorm
         loader.port_weight(
-            keras_variable=decoder_layer._feedforward_layernorm.scale,
+            keras_variable=decoder_layer.post_attention_layernorm.scale,
             hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
         )
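The renames track a refactor of the Keras decoder layer toward Hugging Face style attribute names: `_self_attention_layernorm` becomes `input_layernorm`, `_self_attention_layer._query_dense` becomes `self_attn.q_proj`, the feedforward projections become `mlp.up_proj` / `mlp.down_proj` / `mlp.gate_proj`, and `_feedforward_layernorm` becomes `post_attention_layernorm`. The transpose `hook_fn` on the MLP kernels is needed because a Hugging Face `nn.Linear` stores its weight as `(out_features, in_features)` while a Keras `Dense` kernel is laid out as `(input_dim, units)`; a minimal sketch with made-up dimensions:

```python
import numpy as np

# Hypothetical dimensions, for illustration only.
out_features, in_features = 1536, 576

# Hugging Face nn.Linear weight layout: (out_features, in_features).
hf_weight = np.random.rand(out_features, in_features).astype("float32")

# Keras Dense kernel layout: (input_dim, units), so swap the two axes.
# This mirrors the lambda hook_fn used for the MLP projections above.
keras_kernel = np.transpose(hf_weight, axes=(1, 0))
assert keras_kernel.shape == (in_features, out_features)
```

The attention projections use `transpose_and_reshape` instead, presumably because the transposed weight must also be reshaped to the multi-head shape of the Keras variable.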
