@@ -1,9 +1,6 @@
 import keras
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.transformer_layer_utils import (
-    compute_causal_mask,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding
@@ -78,7 +75,7 @@ def __init__(
             output_dim=hidden_dim,
             name="token_embedding",
         )
-        self.transformer_layers = []
+        self.decoder_layers = []
 
         for i in range(num_layers):
             layer = SmolLM3DecoderLayer(
@@ -94,7 +91,7 @@ def __init__(
                 mlp_bias=mlp_bias,
                 rms_norm_epsilon=rms_norm_epsilon,
             )
-            self.transformer_layers.append(layer)
+            self.decoder_layers.append(layer)
 
         self.norm = keras.layers.RMSNormalization(
             epsilon=layer_norm_epsilon,
@@ -117,22 +114,19 @@ def __init__(
             shape=(None,), dtype="int32", name="position_ids"
         )
 
+        print("token id", token_id_input.shape)
         hidden_states = self.token_embedding(token_id_input)
+        print("hidden states id", hidden_states.shape)
         position_embeddings = self.rotary_embedding(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[:num_hidden_layers]:
+        for decoder_layer in self.decoder_layers[:num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=compute_causal_mask(
-                    hidden_states.shape[0],
-                    hidden_states.shape[1],
-                    hidden_states.shape[1],
-                ),
                 position_embeddings=position_embeddings,
                 **kwargs,
             )
 
-        sequence_output = self.layer_norm(x)
+        sequence_output = self.layer_norm(hidden_states)
         super().__init__(
             inputs={
                 "token_ids": token_id_input,
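For context on the removed attention_mask argument: the deleted compute_causal_mask(batch, q_len, kv_len) call produced a lower-triangular mask so each query position can attend only to itself and earlier key positions. Below is a minimal sketch of that semantics using keras.ops; the causal_mask helper name is made up here, and this is an illustration of the mask involved, not the keras_hub utility's actual implementation.

    import keras

    def causal_mask(batch_size, q_len, kv_len):
        # True where query position i may attend to key position j (j <= i).
        row = keras.ops.expand_dims(keras.ops.arange(q_len), axis=1)   # (q_len, 1)
        col = keras.ops.expand_dims(keras.ops.arange(kv_len), axis=0)  # (1, kv_len)
        mask = keras.ops.less_equal(col, row)                          # (q_len, kv_len)
        # One mask per batch entry: (batch_size, q_len, kv_len).
        mask = keras.ops.expand_dims(mask, axis=0)
        return keras.ops.broadcast_to(mask, (batch_size, q_len, kv_len))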