@@ -30,9 +30,6 @@ def convert_backbone_config(transformers_config):
         "rope_layer_enabled_list": transformers_config["no_rope_layers"],
         "layer_types": transformers_config["layer_types"],
         "mlp_bias": transformers_config["mlp_bias"],
-        "num_hidden_layers": transformers_config[
-            "num_hidden_layers"
-        ],  # Redundant with num_layers, but kept for completeness
     }


@@ -50,41 +47,41 @@ def transpose_and_reshape(x, shape):

         # Input layernorm
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layernorm.scale,
+            keras_variable=decoder_layer.input_layernorm.scale,
             hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
         )

         # Attention layers

         ## Query
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
+            keras_variable=decoder_layer.self_attn.q_proj.kernel,
             hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
             hook_fn=transpose_and_reshape,
         )
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._query_dense_layer_norm.scale,
+            keras_variable=decoder_layer.self_attn.q_norm.scale,
             hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight",
         )
         ## Key
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
+            keras_variable=decoder_layer.self_attn.k_proj.kernel,
             hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
             hook_fn=transpose_and_reshape,
         )
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._key_dense_layer_norm.scale,
+            keras_variable=decoder_layer.self_attn.k_norm.scale,
             hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight",
         )
         ## Value
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._value_dense.kernel,
+            keras_variable=decoder_layer.self_attn.v_proj.kernel,
             hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
             hook_fn=transpose_and_reshape,
         )
         ## Output
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer._output_dense.kernel,
+            keras_variable=decoder_layer.self_attn.o_proj.kernel,
             hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
             # rearrange_patterns="c (a b) -> a b c",
             # rearrange_dims={"a": backbone.num_query_heads},
@@ -93,27 +90,27 @@ def transpose_and_reshape(x, shape):

         # MLP layers
         loader.port_weight(
-            keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
+            keras_variable=decoder_layer.mlp.up_proj.kernel,
             hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight",
             # rearrange_patterns="b a -> a b",
             hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
         )
         loader.port_weight(
-            keras_variable=decoder_layer._feedforward_output_dense.kernel,
+            keras_variable=decoder_layer.mlp.down_proj.kernel,
             hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight",
             # rearrange_patterns="b a -> a b",
             hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
         )
         loader.port_weight(
-            keras_variable=decoder_layer._feedforward_gate_dense.kernel,
+            keras_variable=decoder_layer.mlp.gate_proj.kernel,
             hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight",
             # rearrange_patterns="b a -> a b",
             hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
         )

         # Feedforward layernorm
         loader.port_weight(
-            keras_variable=decoder_layer._feedforward_layernorm.scale,
+            keras_variable=decoder_layer.post_attention_layernorm.scale,
             hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
         )

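Note on the hook_fn transposes retained above: a Hugging Face nn.Linear stores its weight as (out_features, in_features), while a Keras Dense-style kernel is laid out (input_dim, output_dim), so each MLP weight is flipped before assignment; the transpose_and_reshape hook on the attention projections presumably does the same flip plus a reshape to the per-head kernel shape. A minimal NumPy sketch of what the lambda hooks do (the example array is made up for illustration):

    import numpy as np

    # A stand-in for a Hugging Face linear weight with out_features=3, in_features=2.
    hf_weight = np.arange(6, dtype="float32").reshape(3, 2)

    # Same operation as the diff's hook_fn: lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0))
    keras_kernel = np.transpose(hf_weight, axes=(1, 0))

    assert keras_kernel.shape == (2, 3)  # (input_dim, output_dim) layout for the Keras kernel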