 from keras import activations
+from keras import initializers
 from keras import layers
 from keras import ops
 
 from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
 from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
+from keras_hub.src.models.smollm3.smollm3_utils import rope_init
 
 
 class SmolLM3Attention(layers.Layer):
|
@@ -242,3 +244,64 @@ def call(
|
     hidden_states = ops.add(residual, hidden_states)
 
     return hidden_states
+
+
+class SmolLM3RotaryEmbedding(layers.Layer):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_position_embeddings: int,
+        rope_theta: float,
+        partial_rotary_factor: float,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.partial_rotary_factor = partial_rotary_factor
+
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+        inv_freq_tensor, self.attention_scaling = rope_init(
+            self.rope_theta, self.partial_rotary_factor, self.head_dim
+        )
+
+        self.inv_freq = self.add_weight(
+            name="inv_freq",
+            shape=ops.shape(inv_freq_tensor),
+            dtype=inv_freq_tensor.dtype,
+            initializer=initializers.Constant(
+                ops.convert_to_numpy(inv_freq_tensor)
+            ),
+            trainable=False,  # This weight is not trained
+        )
+        self.original_inv_freq = self.inv_freq
+
+    def call(self, x, position_ids):
+        inv_freq_expanded = ops.expand_dims(
+            ops.expand_dims(self.inv_freq, axis=0), axis=-1
+        )
+
+        batch_size = ops.shape(position_ids)[0]
+        inv_freq_expanded = ops.broadcast_to(
+            inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1)
+        )
+
+        position_ids_expanded = ops.expand_dims(position_ids, axis=1)
+
+        freqs = ops.matmul(
+            ops.cast(inv_freq_expanded, "float32"),
+            ops.cast(position_ids_expanded, "float32"),
+        )
+
+        freqs = ops.transpose(freqs, axes=(0, 2, 1))
+
+        emb = ops.concatenate((freqs, freqs), axis=-1)
+
+        cos = ops.cos(emb) * self.attention_scaling
+        sin = ops.sin(emb) * self.attention_scaling
+
+        return ops.cast(cos, x.dtype), ops.cast(sin, x.dtype)
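
For reference, a minimal usage sketch of the new layer, assuming `SmolLM3RotaryEmbedding` is importable from this module; the `hidden_size`, `num_attention_heads`, and `rope_theta` values below are illustrative placeholders, not the actual SmolLM3 configuration:

import numpy as np
from keras import ops

# Hypothetical configuration values, chosen only to illustrate shapes.
rope = SmolLM3RotaryEmbedding(
    hidden_size=2048,
    num_attention_heads=16,
    max_position_embeddings=4096,
    rope_theta=10000.0,
    partial_rotary_factor=1.0,
)

# Dummy hidden states and positions: batch of 2, sequence length 8.
x = ops.zeros((2, 8, 2048), dtype="float32")
position_ids = ops.convert_to_tensor(
    np.tile(np.arange(8), (2, 1)), dtype="int32"
)

# The layer returns the cos/sin tables, presumably consumed by
# apply_rotary_pos_emb inside SmolLM3Attention.
cos, sin = rope(x, position_ids)
print(ops.shape(cos))  # expected: (2, 8, 128), i.e. (batch, seq_len, head_dim)

With `partial_rotary_factor=1.0`, the duplicated frequency table in `call` should span the full `head_dim` (here 2048 // 16 = 128), and the `cos`/`sin` outputs are cast back to the dtype of the incoming hidden states.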