
Commit 598fd74

Add SmolLM3RotaryEmbedding
1 parent 2448d80 commit 598fd74

File tree

  keras_hub/src/models/smollm3/smollm3_layers.py
  keras_hub/src/models/smollm3/smollm3_utils.py

2 files changed: +74 −0 lines changed


keras_hub/src/models/smollm3/smollm3_layers.py

Lines changed: 63 additions & 0 deletions
@@ -1,9 +1,11 @@
 from keras import activations
+from keras import initializers
 from keras import layers
 from keras import ops
 
 from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
 from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
+from keras_hub.src.models.smollm3.smollm3_utils import rope_init
 
 
 class SmolLM3Attention(layers.Layer):
@@ -242,3 +244,64 @@ def call(
         hidden_states = ops.add(residual, hidden_states)
 
         return hidden_states
+
+
+class SmolLM3RotaryEmbedding(layers.Layer):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_position_embeddings: int,
+        rope_theta: float,
+        partial_rotary_factor: float,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.partial_rotary_factor = partial_rotary_factor
+
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+        inv_freq_tensor, self.attention_scaling = rope_init(
+            self.rope_theta, self.partial_rotary_factor, self.head_dim
+        )
+
+        self.inv_freq = self.add_weight(
+            name="inv_freq",
+            shape=ops.shape(inv_freq_tensor),
+            dtype=inv_freq_tensor.dtype,
+            initializer=initializers.Constant(
+                ops.convert_to_numpy(inv_freq_tensor)
+            ),
+            trainable=False,  # This weight is not trained
+        )
+        self.original_inv_freq = self.inv_freq
+
+    def call(self, x, position_ids):
+        inv_freq_expanded = ops.expand_dims(
+            ops.expand_dims(self.inv_freq, axis=0), axis=-1
+        )
+
+        batch_size = ops.shape(position_ids)[0]
+        inv_freq_expanded = ops.broadcast_to(
+            inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1)
+        )
+
+        position_ids_expanded = ops.expand_dims(position_ids, axis=1)
+
+        freqs = ops.matmul(
+            ops.cast(inv_freq_expanded, "float32"),
+            ops.cast(position_ids_expanded, "float32"),
+        )
+
+        freqs = ops.transpose(freqs, axes=(0, 2, 1))
+
+        emb = ops.concatenate((freqs, freqs), axis=-1)
+
+        cos = ops.cos(emb) * self.attention_scaling
+        sin = ops.sin(emb) * self.attention_scaling
+
+        return ops.cast(cos, x.dtype), ops.cast(sin, x.dtype)
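
For reference, a minimal sketch of how the new layer might be exercised once this change lands. The configuration values (hidden_size=64, 4 heads, sequence length 8) are illustrative assumptions, not values taken from this commit or from any SmolLM3 preset:

    from keras import ops

    from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding

    # Illustrative hyperparameters (assumed): head_dim = 64 // 4 = 16.
    rotary = SmolLM3RotaryEmbedding(
        hidden_size=64,
        num_attention_heads=4,
        max_position_embeddings=128,
        rope_theta=10000.0,
        partial_rotary_factor=1.0,
    )

    # Dummy inputs: batch of 2, sequence length 8. `x` is only used for its dtype.
    x = ops.zeros((2, 8, 64), dtype="float32")
    position_ids = ops.broadcast_to(ops.expand_dims(ops.arange(8), 0), (2, 8))

    cos, sin = rotary(x, position_ids)
    # cos and sin both have shape (2, 8, 16) == (batch, seq_len, head_dim),
    # i.e. the tables consumed by apply_rotary_pos_emb inside the attention layer.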

keras_hub/src/models/smollm3/smollm3_utils.py

Lines changed: 11 additions & 0 deletions
@@ -56,3 +56,14 @@ def eager_attention_forward(
     attn_output = ops.transpose(attn_output, axes=(0, 2, 1, 3))
 
     return attn_output, attn_weights
+
+
+def rope_init(rope_theta: float, partial_rotary_factor: float, head_dim: int):
+    base = rope_theta
+    dim = int(head_dim * partial_rotary_factor)
+
+    inv_freq = 1.0 / (
+        ops.power(base, ops.arange(0, dim, 2, dtype="float32") / dim)
+    )
+    attention_scaling = 1.0
+    return inv_freq, attention_scaling
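
This helper computes the standard RoPE inverse-frequency schedule, inv_freq[j] = 1 / rope_theta^(2j / dim) for j = 0, ..., dim/2 − 1 with dim = head_dim * partial_rotary_factor, and returns a constant attention scaling of 1.0. A small sanity check against NumPy; the parameter values here are assumptions chosen for illustration, not values from the commit:

    import numpy as np
    from keras import ops

    from keras_hub.src.models.smollm3.smollm3_utils import rope_init

    # With head_dim=8 and full rotation, dim=8 and the schedule yields four
    # frequencies: 1 / 10000 ** (i / 8) for i in (0, 2, 4, 6).
    inv_freq, attention_scaling = rope_init(
        rope_theta=10000.0, partial_rotary_factor=1.0, head_dim=8
    )

    expected = 1.0 / (10000.0 ** (np.arange(0, 8, 2, dtype="float32") / 8))
    np.testing.assert_allclose(
        ops.convert_to_numpy(inv_freq), expected, rtol=1e-6
    )
    print(attention_scaling)  # 1.0 -- no extra scaling in this schedule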
