Commit 8aebfd1

refactor rotary embeddings

1 parent c8b7423

2 files changed: +20 -19 lines changed

keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 13 additions & 16 deletions

```diff
@@ -8,7 +8,6 @@
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding
-from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 
 
 @keras_hub_export(
@@ -70,7 +69,6 @@ def __init__(
         max_position_embeddings,
         rope_theta,
         partial_rotary_factor,
-        rope_scaling=1,
         **kwargs,
     ):
         # === Layers ===
@@ -102,17 +100,12 @@ def __init__(
             name="sequence_output_layernorm",
         )
 
-        #self.rotary_embedding = SmolLM3RotaryEmbedding(
-        #    hidden_size=hidden_dim,
-        #    num_attention_heads=num_attention_heads,
-        #    max_position_embeddings=max_position_embeddings,
-        #    rope_theta=rope_theta,
-        #    partial_rotary_factor=partial_rotary_factor,
-        #)
-        self.rotary_embedding = RotaryEmbedding(
-            max_wavelength=rope_theta,
-            scaling_factor=rope_scaling,
-            dtype=self.token_embedding.dtype_policy
+        self.rotary_embedding = SmolLM3RotaryEmbedding(
+            hidden_size=hidden_dim,
+            num_attention_heads=num_attention_heads,
+            max_position_embeddings=max_position_embeddings,
+            rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
         )
 
         # === Functional Model ===
@@ -124,8 +117,14 @@ def __init__(
             shape=(None,), dtype="int32", name="padding_mask"
         )
 
+        cache_update_index = kwargs.get('self_attention_cache_index')
+
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
         hidden_states = self.token_embedding(token_id_input)
-        position_embeddings = self.rotary_embedding(hidden_states)
+        position_embeddings = self.rotary_embedding(hidden_states, start_index)
 
         for decoder_layer in self.transformer_layers[:num_layers]:
             hidden_states = decoder_layer(
@@ -161,7 +160,6 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.rope_theta = rope_theta
        self.partial_rotary_factor = partial_rotary_factor
-        self.rope_scaling = rope_scaling
 
     def get_config(self):
         config = super().get_config()
@@ -182,7 +180,6 @@ def get_config(self):
                 "max_position_embeddings": self.max_position_embeddings,
                 "rope_theta": self.rope_theta,
                 "partial_rotary_factor": self.partial_rotary_factor,
-                "rope_scaling": self.rope_scaling
             }
         )
         return config
```
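
The backbone side of the refactor swaps the generic `RotaryEmbedding` (and its `rope_scaling` knob) for the model-specific `SmolLM3RotaryEmbedding`, and threads the optional `self_attention_cache_index` kwarg through as a `start_index`, so cached decoding applies rotations at the correct absolute positions. A minimal sketch of that offset logic, independent of keras-hub (`rotary_positions` is an illustrative name, not library API):

```python
# Illustrative sketch only, not keras-hub code: why the backbone forwards
# a start index to the rotary layer. During cached generation each call
# sees only the newly generated tokens, so their absolute positions must
# be offset by the number of tokens already held in the attention cache.
def rotary_positions(seq_len, cache_update_index=None):
    # Mirrors the backbone logic above: fall back to 0 when no cache
    # index is supplied (the full-sequence/prefill forward pass).
    start_index = cache_update_index if cache_update_index is not None else 0
    return [start_index + i for i in range(seq_len)]

print(rotary_positions(4))                        # prefill: [0, 1, 2, 3]
print(rotary_positions(1, cache_update_index=4))  # decode step: [4]
```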

keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -583,7 +583,7 @@ def build(self, input_shape):
     def call(
         self,
         x,
-        position_ids,
+        start_index=0,
     ):
         """
         Forward pass for SmolLM3RotaryEmbedding.
@@ -596,13 +596,17 @@
         inv_freq_expanded = ops.expand_dims(
             ops.expand_dims(self.inv_freq, axis=0), axis=-1
         )
+
+        batch_size = ops.shape(x)[0]
+        seq_len = ops.shape(x)[1]
+        positions = ops.arange(seq_len, dtype="float32")
+        positions = positions + ops.cast(start_index, dtype="float32")
 
-        batch_size = ops.shape(position_ids)[0]
         inv_freq_expanded = ops.broadcast_to(
             inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1)
         )
 
-        position_ids_expanded = ops.expand_dims(position_ids, axis=1)
+        position_ids_expanded = ops.expand_dims(positions, axis=0)
 
         freqs = ops.matmul(
             ops.cast(inv_freq_expanded, "float32"),
```
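
With the layer now deriving positions from the input shape instead of an explicit `position_ids` tensor, the frequency computation can be sketched end to end. This is a self-contained approximation using Keras 3's `keras.ops`, not the keras-hub layer itself; `inv_freq` stands in for the layer's precomputed inverse-frequency buffer and `rotary_freqs` is a hypothetical name:

```python
# A runnable sketch of the refactored computation, assuming Keras 3.
# `inv_freq` and `rotary_freqs` are stand-ins, not keras-hub API.
import numpy as np
from keras import ops

def rotary_freqs(x, inv_freq, start_index=0):
    batch_size = ops.shape(x)[0]
    seq_len = ops.shape(x)[1]
    # Positions are derived from the input length and shifted by the
    # cache index, replacing the explicit `position_ids` argument.
    positions = ops.arange(seq_len, dtype="float32")
    positions = positions + ops.cast(start_index, dtype="float32")
    # (batch, dim // 2, 1) @ (1, 1, seq_len) -> (batch, dim // 2, seq_len)
    inv_freq_expanded = ops.broadcast_to(
        ops.expand_dims(ops.expand_dims(inv_freq, axis=0), axis=-1),
        (batch_size, ops.shape(inv_freq)[0], 1),
    )
    positions_expanded = ops.reshape(positions, (1, 1, -1))
    return ops.matmul(inv_freq_expanded, positions_expanded)

x = np.zeros((2, 5, 64), dtype="float32")  # (batch, seq_len, hidden)
inv_freq = 1.0 / (10000.0 ** (np.arange(0, 64, 2, dtype="float32") / 64.0))
print(ops.shape(rotary_freqs(x, inv_freq, start_index=3)))  # (2, 32, 5)
```

Broadcasting the positions as a (1, 1, seq_len) operand reproduces what the old (batch, 1, seq_len) `position_ids` tensor achieved, without materializing per-batch position ids.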
