 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding
-from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding


 @keras_hub_export(
@@ -70,7 +69,6 @@ def __init__(
         max_position_embeddings,
         rope_theta,
         partial_rotary_factor,
-        rope_scaling=1,
         **kwargs,
     ):
         # === Layers ===
@@ -102,17 +100,12 @@ def __init__(
             name="sequence_output_layernorm",
         )

-        #self.rotary_embedding = SmolLM3RotaryEmbedding(
-        #    hidden_size=hidden_dim,
-        #    num_attention_heads=num_attention_heads,
-        #    max_position_embeddings=max_position_embeddings,
-        #    rope_theta=rope_theta,
-        #    partial_rotary_factor=partial_rotary_factor,
-        #)
-        self.rotary_embedding = RotaryEmbedding(
-            max_wavelength=rope_theta,
-            scaling_factor=rope_scaling,
-            dtype=self.token_embedding.dtype_policy
+        self.rotary_embedding = SmolLM3RotaryEmbedding(
+            hidden_size=hidden_dim,
+            num_attention_heads=num_attention_heads,
+            max_position_embeddings=max_position_embeddings,
+            rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
         )

         # === Functional Model ===
@@ -124,8 +117,14 @@ def __init__(
             shape=(None,), dtype="int32", name="padding_mask"
         )

+        cache_update_index = kwargs.get('self_attention_cache_index')
+
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
         hidden_states = self.token_embedding(token_id_input)
-        position_embeddings = self.rotary_embedding(hidden_states)
+        position_embeddings = self.rotary_embedding(hidden_states, start_index)

         for decoder_layer in self.transformer_layers[:num_layers]:
             hidden_states = decoder_layer(
@@ -161,7 +160,6 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.rope_theta = rope_theta
         self.partial_rotary_factor = partial_rotary_factor
-        self.rope_scaling = rope_scaling

     def get_config(self):
         config = super().get_config()
@@ -182,7 +180,6 @@ def get_config(self):
                 "max_position_embeddings": self.max_position_embeddings,
                 "rope_theta": self.rope_theta,
                 "partial_rotary_factor": self.partial_rotary_factor,
-                "rope_scaling": self.rope_scaling
             }
         )
         return config
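
For context on the functional-model change above: the rotary call now receives a start_index derived from self_attention_cache_index, so that during cached decoding the newest token is rotated by its absolute position rather than by position zero. Below is a minimal NumPy sketch of that offset; rope_angles is a hypothetical helper for illustration only, not the SmolLM3RotaryEmbedding implementation.

import numpy as np

def rope_angles(seq_len, head_dim, start_index=0, theta=10000.0):
    # Standard RoPE frequencies, one per rotated channel pair.
    inv_freq = 1.0 / theta ** (np.arange(0, head_dim, 2) / head_dim)
    # Absolute positions begin at `start_index` instead of 0; this is the
    # role played by `self_attention_cache_index` in the diff above.
    positions = np.arange(start_index, start_index + seq_len)
    return np.outer(positions, inv_freq)  # shape: (seq_len, head_dim // 2)

# A 6-token uncached pass versus decoding the 6th token with 5 positions cached:
full_pass = rope_angles(seq_len=6, head_dim=64)
cached_step = rope_angles(seq_len=1, head_dim=64, start_index=5)
assert np.allclose(full_pass[5], cached_step[0])  # identical rotation angles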