From f2dedc4b11aaa115c2d4a8cabf73dc5cc8f9ea90 Mon Sep 17 00:00:00 2001 From: David Landup Date: Mon, 14 Jul 2025 21:46:31 +0900 Subject: [PATCH 01/76] add first few utils --- keras_hub/src/models/smollm3/smollm3_utils.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 keras_hub/src/models/smollm3/smollm3_utils.py diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py new file mode 100644 index 0000000000..15ebc5c31e --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -0,0 +1,24 @@ +from keras import ops + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return ops.concatenate((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1): + cos = ops.expand_dims(cos, expansion_axis) + sin = ops.expand_dims(sin, expansion_axis) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states, n_rep): + batch, num_key_value_heads, slen, head_dim = ops.shape(hidden_states) + if n_rep == 1: + return hidden_states + hidden_states = ops.expand_dims(hidden_states, axis=2) + target_shape = (batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = ops.broadcast_to(hidden_states, target_shape) + return ops.reshape(hidden_states, [batch, num_key_value_heads * n_rep, slen, head_dim]) From 1d90715333fd83f5d78a7d56d4c3b93be2636842 Mon Sep 17 00:00:00 2001 From: David Landup Date: Mon, 14 Jul 2025 22:15:33 +0900 Subject: [PATCH 02/76] add eager attention forward --- keras_hub/src/models/smollm3/smollm3_utils.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 15ebc5c31e..e902d0872b 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -1,4 +1,5 @@ from keras import ops +from keras import random def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] @@ -22,3 +23,30 @@ def repeat_kv(hidden_states, n_rep): target_shape = (batch, num_key_value_heads, n_rep, slen, head_dim) hidden_states = ops.broadcast_to(hidden_states, target_shape) return ops.reshape(hidden_states, [batch, num_key_value_heads * n_rep, slen, head_dim]) + + +def eager_attention_forward( + module, + query, + key, + value, + attention_mask, + scaling: float, + dropout: float = 0.0, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = ops.matmul(query, ops.transpose(key_states, axes=(0, 1, 3, 2))) * scaling + + # Apply attention mask if provided + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]] + attn_weights = ops.add(attn_weights, causal_mask) + + attn_weights = ops.softmax(attn_weights, axis=-1) + attn_weights = random.dropout(attn_weights, rate=dropout) + attn_output = ops.matmul(attn_weights, value_states) + attn_output = ops.transpose(attn_output, axes=(0, 2, 1, 3)) + + return attn_output, attn_weights From e5a8f3356f772972ee6a61c6d9146cf60e5c5b42 Mon Sep 17 00:00:00 2001 From: David Landup Date: Mon, 14 Jul 2025 22:44:43 +0900 Subject: [PATCH 03/76] Add SmolLM3Attention --- .../src/models/smollm3/smollm3_layers.py | 129 ++++++++++++++++++ keras_hub/src/models/smollm3/smollm3_utils.py | 12 +- 2 files changed, 138 insertions(+), 3 deletions(-) create mode 100644 
keras_hub/src/models/smollm3/smollm3_layers.py diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py new file mode 100644 index 0000000000..a1c3ab068d --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -0,0 +1,129 @@ +from keras import layers +from keras import ops +from keras.layers import Layer + +from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb +from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward + + +class SmolLM3Attention(Layer): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + attention_bias: bool, + attention_dropout: float, + no_rope_layers: list[bool], + layer_types: list[str], + _attn_implementation: str, + layer_idx: int, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.no_rope_layers = no_rope_layers + self.layer_types = layer_types + self._attn_implementation = _attn_implementation + + self.layer_idx = layer_idx + + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_groups = ( + self.num_attention_heads // self.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.is_causal = True + + self.q_proj = layers.Dense( + self.num_attention_heads * self.head_dim, + use_bias=self.attention_bias, + name="q_proj", + ) + self.k_proj = layers.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=self.attention_bias, + name="k_proj", + ) + self.v_proj = layers.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=self.attention_bias, + name="v_proj", + ) + self.o_proj = layers.Dense( + self.hidden_size, use_bias=self.attention_bias, name="o_proj" + ) + + self.use_rope = ( + self.no_rope_layers[self.layer_idx] + if self.layer_idx < len(self.no_rope_layers) + else True + ) # Default to True if index out of bounds + + self._attention_interface = eager_attention_forward + + def call( + self, + hidden_states, + position_embeddings, + attention_mask, + training=False, + **kwargs, + ): + self.training = training + + input_shape = ops.shape(hidden_states)[ + :-1 + ] # Exclude last dim (hidden_size) + + hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim) + + query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape) + query_states = ops.transpose( + query_states, axes=(0, 2, 1, 3) + ) # (batch, num_heads, seq_len, head_dim) + + # For key and value, the kv_hidden_shape should be based on num_key_value_heads + kv_hidden_shape = ( + *input_shape, + self.num_key_value_heads, + self.head_dim, + ) + key_states = ops.reshape(self.k_proj(hidden_states), kv_hidden_shape) + key_states = ops.transpose( + key_states, axes=(0, 2, 1, 3) + ) # (batch, num_key_value_heads, seq_len, head_dim) + + value_states = ops.reshape(self.v_proj(hidden_states), kv_hidden_shape) + value_states = ops.transpose( + value_states, axes=(0, 2, 1, 3) + ) # (batch, num_key_value_heads, seq_len, head_dim) + + if self.use_rope: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + attn_output, attn_weights = self._attention_interface( + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + dropout=self.attention_dropout, + 
scaling=self.scaling, + training=self.training, + **kwargs, + ) + + attn_output = ops.reshape(attn_output, (*input_shape, self.hidden_size)) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index e902d0872b..486be9f889 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -1,6 +1,7 @@ from keras import ops from keras import random + def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] @@ -22,7 +23,9 @@ def repeat_kv(hidden_states, n_rep): hidden_states = ops.expand_dims(hidden_states, axis=2) target_shape = (batch, num_key_value_heads, n_rep, slen, head_dim) hidden_states = ops.broadcast_to(hidden_states, target_shape) - return ops.reshape(hidden_states, [batch, num_key_value_heads * n_rep, slen, head_dim]) + return ops.reshape( + hidden_states, [batch, num_key_value_heads * n_rep, slen, head_dim] + ) def eager_attention_forward( @@ -37,8 +40,11 @@ def eager_attention_forward( key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = ops.matmul(query, ops.transpose(key_states, axes=(0, 1, 3, 2))) * scaling - + attn_weights = ( + ops.matmul(query, ops.transpose(key_states, axes=(0, 1, 3, 2))) + * scaling + ) + # Apply attention mask if provided if attention_mask is not None: causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]] From 54191cae9ce4b9e2e09c67363ad27dc2cc3d908e Mon Sep 17 00:00:00 2001 From: David Landup Date: Mon, 14 Jul 2025 23:01:34 +0900 Subject: [PATCH 04/76] Add SmolLM3MLP --- .../src/models/smollm3/smollm3_layers.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index a1c3ab068d..43eb6b02b8 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -1,3 +1,4 @@ +from keras import activations from keras import layers from keras import ops from keras.layers import Layer @@ -127,3 +128,30 @@ def call( attn_output = self.o_proj(attn_output) return attn_output, attn_weights + + +class SmolLM3MLP(Layer): + def __init__( + self, hidden_size: int, intermediate_size: int, mlp_bias: bool, **kwargs + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.mlp_bias = mlp_bias + + self.gate_proj = layers.Dense( + self.intermediate_size, use_bias=self.mlp_bias, name="gate_proj" + ) + self.up_proj = layers.Dense( + self.intermediate_size, use_bias=self.mlp_bias, name="up_proj" + ) + self.down_proj = layers.Dense( + self.hidden_size, use_bias=self.mlp_bias, name="down_proj" + ) + + def call(self, x): + gate_output = activations.silu(self.gate_proj(x)) + up_output = self.up_proj(x) + intermediate_output = gate_output * up_output + down_proj_output = self.down_proj(intermediate_output) + return down_proj_output From 136973300ca5b6f8d425be73d7c9ba037f3705c5 Mon Sep 17 00:00:00 2001 From: David Landup Date: Mon, 14 Jul 2025 23:49:55 +0900 Subject: [PATCH 05/76] Add SmolLM3DecoderLayer --- .../src/models/smollm3/smollm3_layers.py | 97 ++++++++++++++++++- 1 file changed, 94 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 43eb6b02b8..d05cf37a29 100644 
--- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -1,13 +1,12 @@ from keras import activations from keras import layers from keras import ops -from keras.layers import Layer from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward -class SmolLM3Attention(Layer): +class SmolLM3Attention(layers.Layer): def __init__( self, hidden_size: int, @@ -130,7 +129,7 @@ def call( return attn_output, attn_weights -class SmolLM3MLP(Layer): +class SmolLM3MLP(layers.Layer): def __init__( self, hidden_size: int, intermediate_size: int, mlp_bias: bool, **kwargs ): @@ -155,3 +154,95 @@ def call(self, x): intermediate_output = gate_output * up_output down_proj_output = self.down_proj(intermediate_output) return down_proj_output + + +class SmolLM3DecoderLayer(layers.Layer): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + attention_bias: bool, + attention_dropout: float, + no_rope_layers: list[bool], + layer_types: list[str], + _attn_implementation: str, + layer_idx: int, + intermediate_size: int, # For MLP + mlp_bias: bool, # For MLP + rms_norm_eps: float, # For RMSNorm + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.layer_idx = layer_idx # Store layer_idx + + # Pass all necessary config parameters to SmolLM3AttentionKeras + self.self_attn = SmolLM3Attention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + no_rope_layers=no_rope_layers, + layer_types=layer_types, + _attn_implementation=_attn_implementation, + layer_idx=layer_idx, + name="self_attn", + ) + + self.mlp = SmolLM3MLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + mlp_bias=mlp_bias, + name="mlp", + ) + + self.input_layernorm = layers.RMSNormalization( + epsilon=rms_norm_eps, axis=-1, name="input_layernorm" + ) + self.post_attention_layernorm = layers.RMSNormalization( + epsilon=rms_norm_eps, axis=-1, name="post_attention_layernorm" + ) + + self.attention_type = layer_types[layer_idx] + + def build(self, input_shape): + # Build sub-layers + self.self_attn.build(input_shape) + self.mlp.build(input_shape) + self.input_layernorm.build(input_shape) + self.post_attention_layernorm.build(input_shape) + + super().build(input_shape) + + def call( + self, + hidden_states, + attention_mask=None, + position_embeddings=None, + training=False, # Keras layers have a 'training' argument in call + **kwargs, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_output, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, # Pass position_embeddings + training=training, # Pass training state + **kwargs, + ) + hidden_states = ops.add( + residual, attn_output + ) # Add attention output to residual + + # Fully Connected (MLP) + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = ops.add(residual, hidden_states) + + return hidden_states From 2448d80f4ded0dd9dca1cd765610535f470fd60d Mon Sep 17 00:00:00 2001 From: David Landup Date: Mon, 14 Jul 2025 23:50:34 +0900 Subject: [PATCH 06/76] remove unnecessary comments --- 
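Note (below the fold, not part of the diff): the RoPE/GQA helpers introduced in PATCH 01 can be exercised in isolation. A minimal sketch follows, with toy shapes and a hand-rolled cos/sin table; the theta of 10000 and the full rotary width are assumptions for illustration only.

import numpy as np
from keras import ops

from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
from keras_hub.src.models.smollm3.smollm3_utils import repeat_kv

batch, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 6, 8

# Query/key in the (batch, heads, seq_len, head_dim) layout that the
# attention layer produces after its reshape + transpose.
q = ops.convert_to_tensor(
    np.random.rand(batch, num_heads, seq_len, head_dim).astype("float32")
)
k = ops.convert_to_tensor(
    np.random.rand(batch, num_kv_heads, seq_len, head_dim).astype("float32")
)

# Hand-rolled RoPE table; theta=10000 and rotating the full head_dim are
# illustrative assumptions, not values taken from a checkpoint config.
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2) / head_dim))
freqs = np.arange(seq_len)[:, None] * inv_freq[None, :]  # (seq_len, head_dim // 2)
emb = np.concatenate([freqs, freqs], axis=-1)[None, ...]  # (1, seq_len, head_dim)
cos = ops.convert_to_tensor(np.cos(emb).astype("float32"))
sin = ops.convert_to_tensor(np.sin(emb).astype("float32"))

# expansion_axis=1 turns cos/sin into (1, 1, seq_len, head_dim) so they
# broadcast over the head axis of q and k.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1)

# GQA: tile the 2 key/value heads up to match the 4 query heads.
k_full = repeat_kv(k_rot, n_rep=num_heads // num_kv_heads)
print(ops.shape(q_rot), ops.shape(k_full))  # (1, 4, 6, 8) (1, 4, 6, 8)
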
.../src/models/smollm3/smollm3_layers.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index d05cf37a29..89b7fc9a16 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -168,16 +168,15 @@ def __init__( layer_types: list[str], _attn_implementation: str, layer_idx: int, - intermediate_size: int, # For MLP - mlp_bias: bool, # For MLP - rms_norm_eps: float, # For RMSNorm + intermediate_size: int, + mlp_bias: bool, + rms_norm_eps: float, **kwargs, ): super().__init__(**kwargs) self.hidden_size = hidden_size - self.layer_idx = layer_idx # Store layer_idx + self.layer_idx = layer_idx - # Pass all necessary config parameters to SmolLM3AttentionKeras self.self_attn = SmolLM3Attention( hidden_size=hidden_size, num_attention_heads=num_attention_heads, @@ -221,7 +220,7 @@ def call( hidden_states, attention_mask=None, position_embeddings=None, - training=False, # Keras layers have a 'training' argument in call + training=False, **kwargs, ): residual = hidden_states @@ -231,15 +230,12 @@ def call( attn_output, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - position_embeddings=position_embeddings, # Pass position_embeddings - training=training, # Pass training state + position_embeddings=position_embeddings, + training=training, **kwargs, ) - hidden_states = ops.add( - residual, attn_output - ) # Add attention output to residual + hidden_states = ops.add(residual, attn_output) - # Fully Connected (MLP) residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) From 598fd7479525a4597e384d3903dc382468754c09 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 15 Jul 2025 00:10:32 +0900 Subject: [PATCH 07/76] Add SmolLM3RotaryEmbedding --- .../src/models/smollm3/smollm3_layers.py | 63 +++++++++++++++++++ keras_hub/src/models/smollm3/smollm3_utils.py | 11 ++++ 2 files changed, 74 insertions(+) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 89b7fc9a16..3653d270f4 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -1,9 +1,11 @@ from keras import activations +from keras import initializers from keras import layers from keras import ops from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward +from keras_hub.src.models.smollm3.smollm3_utils import rope_init class SmolLM3Attention(layers.Layer): @@ -242,3 +244,64 @@ def call( hidden_states = ops.add(residual, hidden_states) return hidden_states + + +class SmolLM3RotaryEmbedding(layers.Layer): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + max_position_embeddings: int, + rope_theta: float, + partial_rotary_factor: float, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.partial_rotary_factor = partial_rotary_factor + + self.head_dim = self.hidden_size // self.num_attention_heads + + inv_freq_tensor, self.attention_scaling = rope_init( + self.rope_theta, self.partial_rotary_factor, self.head_dim + ) + + self.inv_freq = self.add_weight( + 
name="inv_freq", + shape=ops.shape(inv_freq_tensor), + dtype=inv_freq_tensor.dtype, + initializer=initializers.Constant( + ops.convert_to_numpy(inv_freq_tensor) + ), + trainable=False, # This weight is not trained + ) + self.original_inv_freq = self.inv_freq + + def call(self, x, position_ids): + inv_freq_expanded = ops.expand_dims( + ops.expand_dims(self.inv_freq, axis=0), axis=-1 + ) + + batch_size = ops.shape(position_ids)[0] + inv_freq_expanded = ops.broadcast_to( + inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1) + ) + + position_ids_expanded = ops.expand_dims(position_ids, axis=1) + + freqs = ops.matmul( + ops.cast(inv_freq_expanded, "float32"), + ops.cast(position_ids_expanded, "float32"), + ) + + freqs = ops.transpose(freqs, axes=(0, 2, 1)) + + emb = ops.concatenate((freqs, freqs), axis=-1) + + cos = ops.cos(emb) * self.attention_scaling + sin = ops.sin(emb) * self.attention_scaling + + return ops.cast(cos, x.dtype), ops.cast(sin, x.dtype) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 486be9f889..ed5f0c4de1 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -56,3 +56,14 @@ def eager_attention_forward( attn_output = ops.transpose(attn_output, axes=(0, 2, 1, 3)) return attn_output, attn_weights + + +def rope_init(rope_theta: float, partial_rotary_factor: float, head_dim: int): + base = rope_theta + dim = int(head_dim * partial_rotary_factor) + + inv_freq = 1.0 / ( + ops.power(base, ops.arange(0, dim, 2, dtype="float32") / dim) + ) + attention_scaling = 1.0 + return inv_freq, attention_scaling From b9e458d27a96a26a30aa65e684875ea39c5a4154 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 15 Jul 2025 00:29:15 +0900 Subject: [PATCH 08/76] add most of smollm3backbone --- .../src/models/smollm3/smollm3_backbone.py | 153 ++++++++++++++++++ .../src/models/smollm3/smollm3_layers.py | 16 +- 2 files changed, 159 insertions(+), 10 deletions(-) create mode 100644 keras_hub/src/models/smollm3/smollm3_backbone.py diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py new file mode 100644 index 0000000000..877c466f23 --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -0,0 +1,153 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer +from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding + + +@keras_hub_export( + [ + "keras_hub.models.SmolLM3Backbone", + "keras_hub.models.SmolLMBackbone", + ] +) +class SmolLM3Backbone(Backbone): + """ + The SmolLM Transformer core architecture with hyperparameters. + + This network implements a Transformer-based decoder network, + SmolLM3, as described in the SmolLM3 model architecture. + It includes the embedding lookups and transformer layers. + + The default constructor gives a fully customizable, randomly initialized + SmolLM3 model with any number of layers, heads, and embedding + dimensions. To load preset architectures and weights, use the `from_preset` + constructor. + + Args: + + + Examples: + + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained SmolLM decoder. 
+ model = keras_hub.models.SmolLM3Backbone.from_preset("...") + model(input_data) + + # Randomly initialized SmolLM3 decoder with custom config. + model = keras_hub.models.SmolLM3Backbone( + ... + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + hidden_dim, + intermediate_dim, + num_layers, + num_attention_heads, + num_key_value_heads, + attention_bias, + attention_dropout, + rope_layer_enabled_list, + layer_types, + mlp_bias, + rms_norm_epsilon, + layer_norm_epsilon, + max_position_embeddings, + rope_theta, + partial_rotary_factor, + **kwargs, + ): + # === Layers === + self.token_embedding = keras.layers.Embedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + name="token_embedding", + ) + self.transformer_layers = [] + + for i in range(num_layers): + layer = SmolLM3DecoderLayer( + hidden_size=hidden_dim, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rope_layer_enabled_list=rope_layer_enabled_list, + layer_types=layer_types, + layer_idx=i, + intermediate_size=intermediate_dim, + mlp_bias=mlp_bias, + rms_norm_epsilon=rms_norm_epsilon, + ) + self.transformer_layers.append(layer) + + self.norm = keras.layers.RMSNormalization( + epsilon=layer_norm_epsilon, + name="sequence_output_layernorm", + ) + + self.rotary_embedding = SmolLM3RotaryEmbedding( + hidden_size=hidden_dim, + num_attention_heads=num_attention_heads, + max_position_embeddings=max_position_embeddings, + rope_theta=rope_theta, + partial_rotary_factor=partial_rotary_factor, + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + position_embeddings = self.rotary_embedding(x) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=#createcausalmask, + position_embeddings=position_embeddings, + **kwargs, + ) + + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + } + ) + return config + diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 3653d270f4..81546dd89b 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -16,9 +16,8 @@ def __init__( num_key_value_heads: int, attention_bias: bool, attention_dropout: float, - no_rope_layers: list[bool], + rope_layer_enabled_list: list[bool], layer_types: list[str], - _attn_implementation: str, layer_idx: int, **kwargs, ): @@ -29,9 +28,8 @@ def __init__( self.num_key_value_heads = num_key_value_heads self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.no_rope_layers = no_rope_layers + self.rope_layer_enabled_list = rope_layer_enabled_list self.layer_types = layer_types - self._attn_implementation = 
_attn_implementation self.layer_idx = layer_idx @@ -62,8 +60,8 @@ def __init__( ) self.use_rope = ( - self.no_rope_layers[self.layer_idx] - if self.layer_idx < len(self.no_rope_layers) + self.rope_layer_enabled_list[self.layer_idx] + if self.layer_idx < len(self.rope_layer_enabled_list) else True ) # Default to True if index out of bounds @@ -166,9 +164,8 @@ def __init__( num_key_value_heads: int, attention_bias: bool, attention_dropout: float, - no_rope_layers: list[bool], + rope_layer_enabled_list: list[bool], layer_types: list[str], - _attn_implementation: str, layer_idx: int, intermediate_size: int, mlp_bias: bool, @@ -185,9 +182,8 @@ def __init__( num_key_value_heads=num_key_value_heads, attention_bias=attention_bias, attention_dropout=attention_dropout, - no_rope_layers=no_rope_layers, + rope_layer_enabled_list=rope_layer_enabled_list, layer_types=layer_types, - _attn_implementation=_attn_implementation, layer_idx=layer_idx, name="self_attn", ) From 6a53a7d437e0937399981d299fe04ed327e350e5 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 16 Jul 2025 19:49:52 +0900 Subject: [PATCH 09/76] Fix calls within causal model --- .../src/models/smollm3/smollm3_backbone.py | 24 ++++++++++++------- .../src/models/smollm3/smollm3_layers.py | 6 ++--- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 877c466f23..e1b09b0f16 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -1,6 +1,9 @@ import keras from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding @@ -66,6 +69,7 @@ def __init__( max_position_embeddings, rope_theta, partial_rotary_factor, + num_hidden_layers, **kwargs, ): # === Layers === @@ -109,16 +113,21 @@ def __init__( token_id_input = keras.Input( shape=(None,), dtype="int32", name="token_ids" ) - padding_mask_input = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" + position_ids = keras.Input( + shape=(None,), dtype="int32", name="position_ids" ) - x = self.token_embedding(token_id_input) - position_embeddings = self.rotary_embedding(x) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = self.token_embedding(token_id_input) + position_embeddings = self.rotary_embedding(hidden_states, position_ids) + + for decoder_layer in self.layers[:num_hidden_layers]: hidden_states = decoder_layer( hidden_states, - attention_mask=#createcausalmask, + attention_mask=compute_causal_mask( + hidden_states.shape[0], + hidden_states.shape[1], + hidden_states.shape[1], + ), position_embeddings=position_embeddings, **kwargs, ) @@ -127,7 +136,6 @@ def __init__( super().__init__( inputs={ "token_ids": token_id_input, - "padding_mask": padding_mask_input, }, outputs=sequence_output, **kwargs, @@ -137,7 +145,6 @@ def __init__( self.vocabulary_size = vocabulary_size self.num_layers = num_layers - def get_config(self): config = super().get_config() config.update( @@ -150,4 +157,3 @@ def get_config(self): } ) return config - diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 81546dd89b..0c5542163c 100644 --- 
a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -169,7 +169,7 @@ def __init__( layer_idx: int, intermediate_size: int, mlp_bias: bool, - rms_norm_eps: float, + rms_norm_epsilon: float, **kwargs, ): super().__init__(**kwargs) @@ -196,10 +196,10 @@ def __init__( ) self.input_layernorm = layers.RMSNormalization( - epsilon=rms_norm_eps, axis=-1, name="input_layernorm" + epsilon=rms_norm_epsilon, axis=-1, name="input_layernorm" ) self.post_attention_layernorm = layers.RMSNormalization( - epsilon=rms_norm_eps, axis=-1, name="post_attention_layernorm" + epsilon=rms_norm_epsilon, axis=-1, name="post_attention_layernorm" ) self.attention_type = layer_types[layer_idx] From 81eff7358217c62b051806fbddaeb1cd599a8993 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 16 Jul 2025 20:24:19 +0900 Subject: [PATCH 10/76] Move causal mask computation to forward call --- .../src/models/smollm3/smollm3_backbone.py | 18 ++++++------------ keras_hub/src/models/smollm3/smollm3_layers.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index e1b09b0f16..78f4039703 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -1,9 +1,6 @@ import keras from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.modeling.transformer_layer_utils import ( - compute_causal_mask, -) from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding @@ -78,7 +75,7 @@ def __init__( output_dim=hidden_dim, name="token_embedding", ) - self.transformer_layers = [] + self.decoder_layers = [] for i in range(num_layers): layer = SmolLM3DecoderLayer( @@ -94,7 +91,7 @@ def __init__( mlp_bias=mlp_bias, rms_norm_epsilon=rms_norm_epsilon, ) - self.transformer_layers.append(layer) + self.decoder_layers.append(layer) self.norm = keras.layers.RMSNormalization( epsilon=layer_norm_epsilon, @@ -117,22 +114,19 @@ def __init__( shape=(None,), dtype="int32", name="position_ids" ) + print("token id", token_id_input.shape) hidden_states = self.token_embedding(token_id_input) + print("hidden states id", hidden_states.shape) position_embeddings = self.rotary_embedding(hidden_states, position_ids) - for decoder_layer in self.layers[:num_hidden_layers]: + for decoder_layer in self.decoder_layers[:num_hidden_layers]: hidden_states = decoder_layer( hidden_states, - attention_mask=compute_causal_mask( - hidden_states.shape[0], - hidden_states.shape[1], - hidden_states.shape[1], - ), position_embeddings=position_embeddings, **kwargs, ) - sequence_output = self.layer_norm(x) + sequence_output = self.layer_norm(hidden_states) super().__init__( inputs={ "token_ids": token_id_input, diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 0c5542163c..d397e887c0 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -3,6 +3,9 @@ from keras import layers from keras import ops +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward from 
keras_hub.src.models.smollm3.smollm3_utils import rope_init @@ -216,7 +219,6 @@ def build(self, input_shape): def call( self, hidden_states, - attention_mask=None, position_embeddings=None, training=False, **kwargs, @@ -224,6 +226,14 @@ def call( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) + attention_mask = ( + compute_causal_mask( + ops.shape(hidden_states)[0], + ops.shape(hidden_states)[1], + ops.shape(hidden_states)[1], + ), + ) + # Self Attention attn_output, _ = self.self_attn( hidden_states=hidden_states, From b0080f2df5ebcc4ec6767bf9672e7c5deb6c8c82 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 16 Jul 2025 20:51:19 +0900 Subject: [PATCH 11/76] Add convert_smollm3.py and update preset loader --- .../src/models/smollm3/smollm3_backbone.py | 6 +- .../src/models/smollm3/smollm3_layers.py | 4 +- keras_hub/src/models/smollm3/smollm3_utils.py | 8 +- .../src/utils/transformers/convert_smollm3.py | 157 ++++++++++++++++++ .../src/utils/transformers/preset_loader.py | 3 + 5 files changed, 167 insertions(+), 11 deletions(-) create mode 100644 keras_hub/src/utils/transformers/convert_smollm3.py diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 78f4039703..c7cfba9264 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -61,7 +61,6 @@ def __init__( rope_layer_enabled_list, layer_types, mlp_bias, - rms_norm_epsilon, layer_norm_epsilon, max_position_embeddings, rope_theta, @@ -89,7 +88,7 @@ def __init__( layer_idx=i, intermediate_size=intermediate_dim, mlp_bias=mlp_bias, - rms_norm_epsilon=rms_norm_epsilon, + rms_norm_epsilon=layer_norm_epsilon, ) self.decoder_layers.append(layer) @@ -145,9 +144,6 @@ def get_config(self): { "vocabulary_size": self.vocabulary_size, "num_layers": self.num_layers, - "num_query_heads": self.num_query_heads, - "hidden_dim": self.hidden_dim, - "intermediate_dim": self.intermediate_dim, } ) return config diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index d397e887c0..9a89bc05b4 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -68,8 +68,6 @@ def __init__( else True ) # Default to True if index out of bounds - self._attention_interface = eager_attention_forward - def call( self, hidden_states, @@ -113,7 +111,7 @@ def call( query_states, key_states, cos, sin ) - attn_output, attn_weights = self._attention_interface( + attn_output, attn_weights = eager_attention_forward( module=self, query=query_states, key=key_states, diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index ed5f0c4de1..6edce6bdb7 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -34,8 +34,9 @@ def eager_attention_forward( key, value, attention_mask, - scaling: float, - dropout: float = 0.0, + scaling, + dropout=0.0, + training=False, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -51,7 +52,8 @@ def eager_attention_forward( attn_weights = ops.add(attn_weights, causal_mask) attn_weights = ops.softmax(attn_weights, axis=-1) - attn_weights = random.dropout(attn_weights, rate=dropout) + if not training: + attn_weights = random.dropout(attn_weights, rate=dropout) attn_output = ops.matmul(attn_weights, value_states) 
attn_output = ops.transpose(attn_output, axes=(0, 2, 1, 3)) diff --git a/keras_hub/src/utils/transformers/convert_smollm3.py b/keras_hub/src/utils/transformers/convert_smollm3.py new file mode 100644 index 0000000000..cb62719e75 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_smollm3.py @@ -0,0 +1,157 @@ +import numpy as np + +from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = SmolLM3Backbone + + +def convert_backbone_config(transformers_config): + return { + "vocabulary_size": transformers_config["vocab_size"], + "hidden_dim": transformers_config["hidden_size"], + "num_layers": transformers_config["num_hidden_layers"], + "num_attention_heads": transformers_config["num_attention_heads"], + "num_key_value_heads": transformers_config["num_key_value_heads"], + "intermediate_dim": transformers_config["intermediate_size"], + "layer_norm_epsilon": transformers_config[ + "rms_norm_eps" + ], # Using rms_norm_eps as layer_norm_epsilon + "max_position_embeddings": transformers_config[ + "max_position_embeddings" + ], + "rope_theta": transformers_config["rope_theta"], + # partial_rotary_factor is not explicitly in config.json + # but is inherited from the default value in the `_compute_default_rope_parameters()` + # function + "partial_rotary_factor": 1.0, + "attention_bias": transformers_config["attention_bias"], + "attention_dropout": transformers_config["attention_dropout"], + "rope_layer_enabled_list": transformers_config["no_rope_layers"], + "layer_types": transformers_config["layer_types"], + "mlp_bias": transformers_config["mlp_bias"], + "num_hidden_layers": transformers_config[ + "num_hidden_layers" + ], # Redundant with num_layers, but kept for completeness + } + + +def convert_weights(backbone, loader, transformers_config): + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key="model.embed_tokens.weight", + ) + if not backbone.tie_word_embeddings: + loader.port_weight( + keras_variable=backbone.get_layer( + "token_embedding" + ).reverse_embeddings, + hf_weight_key="lm_head.weight", + # rearrange_pattern="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"transformer_layer_{i}") + + # Input layernorm + loader.port_weight( + keras_variable=decoder_layer._self_attention_layernorm.scale, + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + + # Attention layers + + ## Query + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense_layer_norm.scale, + hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight", + ) + ## Key + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense_layer_norm.scale, + hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight", + ) + ## Value + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, + 
hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + hook_fn=transpose_and_reshape, + ) + ## Output + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + # rearrange_patterns="c (a b) -> a b c", + # rearrange_dims={"a": backbone.num_query_heads}, + hook_fn=transpose_and_reshape, + ) + + # MLP layers + loader.port_weight( + keras_variable=decoder_layer._feedforward_intermediate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_output_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + # Feedforward layernorm + loader.port_weight( + keras_variable=decoder_layer._feedforward_layernorm.scale, + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Final normalization layer + loader.port_weight( + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, + hf_weight_key="model.norm.weight", + ) + + return backbone + + +def convert_tokenizer(cls, preset, **kwargs): + tokenizer_config = load_json(preset, "tokenizer.json") + vocab = tokenizer_config["model"]["vocab"] + merges = tokenizer_config["model"]["merges"] + merges = [" ".join(item) for item in merges] + + # Load all special tokens with the exception of "reserved" ones. 
+ special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + kwargs.update( + { + "unsplittable_tokens": list(special_tokens), + } + ) + + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index fe49a9b269..526922505d 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -17,6 +17,7 @@ from keras_hub.src.utils.transformers import convert_qwen from keras_hub.src.utils.transformers import convert_qwen3 from keras_hub.src.utils.transformers import convert_qwen_moe +from keras_hub.src.utils.transformers import convert_smollm3 from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -56,6 +57,8 @@ def __init__(self, preset, config): self.converter = convert_qwen_moe elif model_type == "qwen3": self.converter = convert_qwen3 + elif model_type == "smollm3": + self.converter = convert_smollm3 else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " From d5767c16f65a5945e03844ad2a127da5a9d92f3b Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 16 Jul 2025 21:15:05 +0900 Subject: [PATCH 12/76] Fix causal mask call --- .../src/models/smollm3/smollm3_backbone.py | 12 +- .../src/models/smollm3/smollm3_layers.py | 204 +++++++++++++++++- keras_hub/src/models/smollm3/smollm3_utils.py | 4 +- .../src/utils/transformers/convert_smollm3.py | 9 - 4 files changed, 202 insertions(+), 27 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index c7cfba9264..58d9fe1459 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -89,6 +89,7 @@ def __init__( intermediate_size=intermediate_dim, mlp_bias=mlp_bias, rms_norm_epsilon=layer_norm_epsilon, + name=f"transformer_layer_{i}", ) self.decoder_layers.append(layer) @@ -109,14 +110,14 @@ def __init__( token_id_input = keras.Input( shape=(None,), dtype="int32", name="token_ids" ) - position_ids = keras.Input( + position_id_input = keras.Input( shape=(None,), dtype="int32", name="position_ids" ) - print("token id", token_id_input.shape) hidden_states = self.token_embedding(token_id_input) - print("hidden states id", hidden_states.shape) - position_embeddings = self.rotary_embedding(hidden_states, position_ids) + position_embeddings = self.rotary_embedding( + hidden_states, position_id_input + ) for decoder_layer in self.decoder_layers[:num_hidden_layers]: hidden_states = decoder_layer( @@ -125,10 +126,11 @@ def __init__( **kwargs, ) - sequence_output = self.layer_norm(hidden_states) + sequence_output = self.norm(hidden_states) super().__init__( inputs={ "token_ids": token_id_input, + "position_ids": position_id_input, }, outputs=sequence_output, **kwargs, diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 9a89bc05b4..175c7631f0 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -12,6 +12,20 @@ class SmolLM3Attention(layers.Layer): + """ + Multi-head attention layer for SmolLM3 model. + + Args: + hidden_size: The hidden size of the attention layer. 
+ num_attention_heads: The number of attention heads. + num_key_value_heads: The number of key-value heads. + attention_bias: Whether to use bias in attention projections. + attention_dropout: Dropout rate for attention weights. + rope_layer_enabled_list: List indicating if RoPE is enabled for each layer. + layer_types: List of layer types. + layer_idx: Index of the current layer. + """ + def __init__( self, hidden_size: int, @@ -76,15 +90,25 @@ def call( training=False, **kwargs, ): + """ + Forward pass for SmolLM3Attention. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size). + position_embeddings: Tuple of (cos, sin) tensors for RoPE. + attention_mask: Attention mask tensor. + training: Whether the layer is in training mode. + """ self.training = training input_shape = ops.shape(hidden_states)[ :-1 ] # Exclude last dim (hidden_size) - hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim) - - query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape) + query_states = ops.reshape( + self.q_proj(hidden_states), + (*input_shape, self.num_attention_heads, self.head_dim), + ) query_states = ops.transpose( query_states, axes=(0, 2, 1, 3) ) # (batch, num_heads, seq_len, head_dim) @@ -129,8 +153,47 @@ def call( return attn_output, attn_weights + def compute_output_shape(self, input_shape): + """ + Computes the output shape of the layer. + + Args: + input_shape: A list/tuple of shapes for the inputs: + [hidden_states_shape, position_embeddings_shape_tuple, attention_mask_shape] + - hidden_states_shape: (batch_size, seq_len, hidden_size) + - position_embeddings_shape_tuple: (cos_shape, sin_shape) where cos_shape/sin_shape is (batch_size, seq_len, head_dim) + - attention_mask_shape: (batch_size, 1, seq_len, seq_len) + + Returns: + A list of output shapes: [output_attn_output_shape, output_attn_weights_shape] + """ + hidden_states_shape = input_shape[0] + + batch_size = hidden_states_shape[0] + seq_len = hidden_states_shape[1] + + output_attn_output_shape = (batch_size, seq_len, self.hidden_size) + + output_attn_weights_shape = ( + batch_size, + self.num_attention_heads, + seq_len, + seq_len, + ) + + return [output_attn_output_shape, output_attn_weights_shape] + class SmolLM3MLP(layers.Layer): + """ + Multi-layer perceptron (MLP) block for SmolLM3 model. + + Args: + hidden_size: The hidden size of the MLP. + intermediate_size: The intermediate size of the MLP. + mlp_bias: Whether to use bias in MLP dense layers. + """ + def __init__( self, hidden_size: int, intermediate_size: int, mlp_bias: bool, **kwargs ): @@ -150,14 +213,50 @@ def __init__( ) def call(self, x): + """ + Forward pass for SmolLM3MLP. + + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size). + """ gate_output = activations.silu(self.gate_proj(x)) up_output = self.up_proj(x) intermediate_output = gate_output * up_output down_proj_output = self.down_proj(intermediate_output) return down_proj_output + def compute_output_shape(self, input_shape): + """ + Computes the output shape of the layer. + + Args: + input_shape: The input shape (batch_size, seq_len, hidden_size). + + Returns: + The output shape, which is the same as the input shape: + (batch_size, seq_len, hidden_size). + """ + return input_shape + class SmolLM3DecoderLayer(layers.Layer): + """ + Decoder layer for SmolLM3 model, combining self-attention and MLP. + + Args: + hidden_size: The hidden size of the layer. + num_attention_heads: The number of attention heads. 
+ num_key_value_heads: The number of key-value heads. + attention_bias: Whether to use bias in attention projections. + attention_dropout: Dropout rate for attention weights. + rope_layer_enabled_list: List indicating if RoPE is enabled for each layer. + layer_types: List of layer types. + layer_idx: Index of the current layer. + intermediate_size: The intermediate size of the MLP. + mlp_bias: Whether to use bias in MLP dense layers. + rms_norm_epsilon: Epsilon for RMSNormalization. + """ + def __init__( self, hidden_size: int, @@ -206,8 +305,25 @@ def __init__( self.attention_type = layer_types[layer_idx] def build(self, input_shape): - # Build sub-layers - self.self_attn.build(input_shape) + """ + Builds the sub-layers based on the input shape. + + Args: + input_shape: The input shape to the decoder layer + (batch_size, seq_len, hidden_size). + """ + # input_shape for SmolLM3DecoderLayer: (batch_size, seq_len, hidden_size) + batch_size = input_shape[0] + seq_len = input_shape[1] + + head_dim = self.self_attn.head_dim + pos_emb_shape = (batch_size, seq_len, head_dim) + + attn_mask_shape = (batch_size, 1, seq_len, seq_len) + + self.self_attn.build( + [input_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape] + ) self.mlp.build(input_shape) self.input_layernorm.build(input_shape) self.post_attention_layernorm.build(input_shape) @@ -221,15 +337,21 @@ def call( training=False, **kwargs, ): + """ + Forward pass for SmolLM3DecoderLayer. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size). + position_embeddings: Optional tuple of (cos, sin) tensors for RoPE. + training: Whether the layer is in training mode. + """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - attention_mask = ( - compute_causal_mask( - ops.shape(hidden_states)[0], - ops.shape(hidden_states)[1], - ops.shape(hidden_states)[1], - ), + attention_mask = compute_causal_mask( + ops.shape(hidden_states)[0], + ops.shape(hidden_states)[1], + ops.shape(hidden_states)[1], ) # Self Attention @@ -249,8 +371,32 @@ def call( return hidden_states + def compute_output_shape(self, input_shape): + """ + Computes the output shape of the layer. + + Args: + input_shape: The input shape (batch_size, seq_len, hidden_size). + + Returns: + The output shape, which is the same as the input shape: + (batch_size, seq_len, hidden_size). + """ + return input_shape + class SmolLM3RotaryEmbedding(layers.Layer): + """ + Rotary Position Embedding (RoPE) layer for SmolLM3 model. + + Args: + hidden_size: The hidden size of the model. + num_attention_heads: The number of attention heads. + max_position_embeddings: The maximum sequence length for position embeddings. + rope_theta: The theta value for RoPE. + partial_rotary_factor: The factor for partial rotary embedding. + """ + def __init__( self, hidden_size: int, @@ -285,6 +431,14 @@ def __init__( self.original_inv_freq = self.inv_freq def call(self, x, position_ids): + """ + Forward pass for SmolLM3RotaryEmbedding. + + Args: + x: Input tensor, typically query or key states. + Shape can vary, but the last dimension is head_dim. + position_ids: Tensor of position IDs of shape (batch_size, seq_len). + """ inv_freq_expanded = ops.expand_dims( ops.expand_dims(self.inv_freq, axis=0), axis=-1 ) @@ -309,3 +463,31 @@ def call(self, x, position_ids): sin = ops.sin(emb) * self.attention_scaling return ops.cast(cos, x.dtype), ops.cast(sin, x.dtype) + + def compute_output_shape(self, input_shape): + """ + Computes the output shape of the layer. 
+ + Args: + input_shape: A list/tuple of shapes for the inputs: + [x_shape, position_ids_shape] + - x_shape: (batch_size, ..., head_dim) + - position_ids_shape: (batch_size, seq_len) + + Returns: + A list of output shapes for (cos, sin): + [(batch_size, seq_len, head_dim), (batch_size, seq_len, head_dim)] + """ + if input_shape[1] is not None and len(input_shape[1]) >= 2: + batch_size = input_shape[1][0] + seq_len = input_shape[1][1] + else: + # Fallback if position_ids_shape is None or malformed. + # In this case, the batch_size and seq_len are unknown. + batch_size = None + seq_len = None + + # The output cos and sin have shape (batch_size, seq_len, head_dim) + output_shape = (batch_size, seq_len, self.head_dim) + + return [output_shape, output_shape] diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 6edce6bdb7..67db41f854 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -48,8 +48,8 @@ def eager_attention_forward( # Apply attention mask if provided if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]] - attn_weights = ops.add(attn_weights, causal_mask) + # causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]] + attn_weights = ops.add(attn_weights, attention_mask) attn_weights = ops.softmax(attn_weights, axis=-1) if not training: diff --git a/keras_hub/src/utils/transformers/convert_smollm3.py b/keras_hub/src/utils/transformers/convert_smollm3.py index cb62719e75..b3f21f004f 100644 --- a/keras_hub/src/utils/transformers/convert_smollm3.py +++ b/keras_hub/src/utils/transformers/convert_smollm3.py @@ -41,15 +41,6 @@ def convert_weights(backbone, loader, transformers_config): keras_variable=backbone.get_layer("token_embedding").embeddings, hf_weight_key="model.embed_tokens.weight", ) - if not backbone.tie_word_embeddings: - loader.port_weight( - keras_variable=backbone.get_layer( - "token_embedding" - ).reverse_embeddings, - hf_weight_key="lm_head.weight", - # rearrange_pattern="b a -> a b", - hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), - ) def transpose_and_reshape(x, shape): return np.reshape(np.transpose(x), shape) From 186eaf8595453c7a44c41c58819a7178dcbdb989 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 16 Jul 2025 21:54:14 +0900 Subject: [PATCH 13/76] Fix conversion weight names --- keras_hub/src/models/smollm3/smollm3_utils.py | 1 - .../src/utils/transformers/convert_smollm3.py | 25 ++++++++----------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 67db41f854..bb852a6915 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -48,7 +48,6 @@ def eager_attention_forward( # Apply attention mask if provided if attention_mask is not None: - # causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]] attn_weights = ops.add(attn_weights, attention_mask) attn_weights = ops.softmax(attn_weights, axis=-1) diff --git a/keras_hub/src/utils/transformers/convert_smollm3.py b/keras_hub/src/utils/transformers/convert_smollm3.py index b3f21f004f..756d74b623 100644 --- a/keras_hub/src/utils/transformers/convert_smollm3.py +++ b/keras_hub/src/utils/transformers/convert_smollm3.py @@ -30,9 +30,6 @@ def convert_backbone_config(transformers_config): "rope_layer_enabled_list": 
transformers_config["no_rope_layers"], "layer_types": transformers_config["layer_types"], "mlp_bias": transformers_config["mlp_bias"], - "num_hidden_layers": transformers_config[ - "num_hidden_layers" - ], # Redundant with num_layers, but kept for completeness } @@ -50,7 +47,7 @@ def transpose_and_reshape(x, shape): # Input layernorm loader.port_weight( - keras_variable=decoder_layer._self_attention_layernorm.scale, + keras_variable=decoder_layer.input_layernorm.scale, hf_weight_key=f"model.layers.{i}.input_layernorm.weight", ) @@ -58,33 +55,33 @@ def transpose_and_reshape(x, shape): ## Query loader.port_weight( - keras_variable=decoder_layer._self_attention_layer._query_dense.kernel, + keras_variable=decoder_layer.self_attn.q_proj.kernel, hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", hook_fn=transpose_and_reshape, ) loader.port_weight( - keras_variable=decoder_layer._self_attention_layer._query_dense_layer_norm.scale, + keras_variable=decoder_layer.self_attn.q_norm.scale, hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight", ) ## Key loader.port_weight( - keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, + keras_variable=decoder_layer.self_attn.k_proj.kernel, hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", hook_fn=transpose_and_reshape, ) loader.port_weight( - keras_variable=decoder_layer._self_attention_layer._key_dense_layer_norm.scale, + keras_variable=decoder_layer.self_attn.k_norm.scale, hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight", ) ## Value loader.port_weight( - keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, + keras_variable=decoder_layer.self_attn.v_proj.kernel, hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", hook_fn=transpose_and_reshape, ) ## Output loader.port_weight( - keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, + keras_variable=decoder_layer.self_attn.o_proj.kernel, hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", # rearrange_patterns="c (a b) -> a b c", # rearrange_dims={"a": backbone.num_query_heads}, @@ -93,19 +90,19 @@ def transpose_and_reshape(x, shape): # MLP layers loader.port_weight( - keras_variable=decoder_layer._feedforward_intermediate_dense.kernel, + keras_variable=decoder_layer.mlp.up_proj.kernel, hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", # rearrange_patterns="b a -> a b", hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), ) loader.port_weight( - keras_variable=decoder_layer._feedforward_output_dense.kernel, + keras_variable=decoder_layer.mlp.down_proj.kernel, hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", # rearrange_patterns="b a -> a b", hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), ) loader.port_weight( - keras_variable=decoder_layer._feedforward_gate_dense.kernel, + keras_variable=decoder_layer.mlp.gate_proj.kernel, hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", # rearrange_patterns="b a -> a b", hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), @@ -113,7 +110,7 @@ def transpose_and_reshape(x, shape): # Feedforward layernorm loader.port_weight( - keras_variable=decoder_layer._feedforward_layernorm.scale, + keras_variable=decoder_layer.post_attention_layernorm.scale, hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", ) From 6ab2e5c4790216a2afec07ba897362e9d897ba56 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 16 Jul 2025 21:58:54 +0900 Subject: [PATCH 14/76] remove unnecessary arg --- 
keras_hub/src/models/smollm3/smollm3_backbone.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 58d9fe1459..5b8848da22 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -65,7 +65,6 @@ def __init__( max_position_embeddings, rope_theta, partial_rotary_factor, - num_hidden_layers, **kwargs, ): # === Layers === @@ -119,7 +118,7 @@ def __init__( hidden_states, position_id_input ) - for decoder_layer in self.decoder_layers[:num_hidden_layers]: + for decoder_layer in self.decoder_layers[:num_layers]: hidden_states = decoder_layer( hidden_states, position_embeddings=position_embeddings, From 6819fd1bdd8be1048b2826e3b74ee17455053173 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 16 Jul 2025 22:10:11 +0900 Subject: [PATCH 15/76] Build all layers --- .../src/models/smollm3/smollm3_layers.py | 59 ++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 175c7631f0..2c60174324 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -82,6 +82,23 @@ def __init__( else True ) # Default to True if index out of bounds + def build(self, input_shape): + """ + Builds the internal Dense layers. + Args: + input_shape: A list/tuple of shapes for the inputs: + [hidden_states_shape, position_embeddings_shape_tuple, attention_mask_shape] + - hidden_states_shape: (batch_size, seq_len, hidden_size) + """ + # The input shape to the Dense layers (q_proj, k_proj, v_proj, o_proj) + # is the same as the hidden_states input to SmolLM3Attention. + hidden_states_shape = input_shape[0] + self.q_proj.build(hidden_states_shape) + self.k_proj.build(hidden_states_shape) + self.v_proj.build(hidden_states_shape) + self.o_proj.build(hidden_states_shape) + super().build(input_shape) + def call( self, hidden_states, @@ -212,6 +229,25 @@ def __init__( self.hidden_size, use_bias=self.mlp_bias, name="down_proj" ) + def build(self, input_shape): + """ + Builds the internal Dense layers. + Args: + input_shape: The shape of the input to this layer + (batch_size, seq_len, hidden_size). + """ + self.gate_proj.build(input_shape) + self.up_proj.build(input_shape) + # The down_proj takes intermediate_output, which has shape + # (batch_size, seq_len, intermediate_size) + down_proj_input_shape = ( + input_shape[0], + input_shape[1], + self.intermediate_size, + ) + self.down_proj.build(down_proj_input_shape) + super().build(input_shape) + def call(self, x): """ Forward pass for SmolLM3MLP. @@ -321,9 +357,13 @@ def build(self, input_shape): attn_mask_shape = (batch_size, 1, seq_len, seq_len) + # Pass the correct input shape to self_attn's build method + # The input_shape for self_attn.build is a list: + # [hidden_states_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape] self.self_attn.build( [input_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape] ) + self.mlp.build(input_shape) self.input_layernorm.build(input_shape) self.post_attention_layernorm.build(input_shape) @@ -430,7 +470,24 @@ def __init__( ) self.original_inv_freq = self.inv_freq - def call(self, x, position_ids): + def build(self, input_shape): + """ + Builds the layer. For SmolLM3RotaryEmbedding, this mainly ensures + that the parent layer's build is called. 
+ Args: + input_shape: A list/tuple of shapes for the inputs: + [x_shape, position_ids_shape] + - x_shape: (batch_size, ..., head_dim) + - position_ids_shape: (batch_size, seq_len) + """ + # No internal layers to explicitly build here, as inv_freq is added in __init__ + super().build(input_shape) + + def call( + self, + x, + position_ids, + ): """ Forward pass for SmolLM3RotaryEmbedding. From e126938b51f01ea2314724d3642535233402567d Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 16 Jul 2025 22:12:33 +0900 Subject: [PATCH 16/76] Remove k and q norms --- keras_hub/src/utils/transformers/convert_smollm3.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/keras_hub/src/utils/transformers/convert_smollm3.py b/keras_hub/src/utils/transformers/convert_smollm3.py index 756d74b623..d93e0c1361 100644 --- a/keras_hub/src/utils/transformers/convert_smollm3.py +++ b/keras_hub/src/utils/transformers/convert_smollm3.py @@ -59,20 +59,12 @@ def transpose_and_reshape(x, shape): hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", hook_fn=transpose_and_reshape, ) - loader.port_weight( - keras_variable=decoder_layer.self_attn.q_norm.scale, - hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight", - ) ## Key loader.port_weight( keras_variable=decoder_layer.self_attn.k_proj.kernel, hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", hook_fn=transpose_and_reshape, ) - loader.port_weight( - keras_variable=decoder_layer.self_attn.k_norm.scale, - hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight", - ) ## Value loader.port_weight( keras_variable=decoder_layer.self_attn.v_proj.kernel, From 26511b230c7b65dd38297c0a1c32f1054db7083d Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 16 Jul 2025 22:47:42 +0900 Subject: [PATCH 17/76] add causal attn mask, a few fixes --- keras_hub/api/models/__init__.py | 24 ++ keras_hub/api/tokenizers/__init__.py | 6 + .../src/models/smollm3/smollm3_backbone.py | 44 ++- .../src/models/smollm3/smollm3_causal_lm.py | 325 ++++++++++++++++++ .../smollm3/smollm3_causal_lm_preprocessor.py | 84 +++++ .../src/models/smollm3/smollm3_layers.py | 172 ++++++--- .../src/models/smollm3/smollm3_tokenizer.py | 60 ++++ keras_hub/src/models/smollm3/smollm3_utils.py | 15 +- .../src/utils/transformers/convert_smollm3.py | 5 +- 9 files changed, 677 insertions(+), 58 deletions(-) create mode 100644 keras_hub/src/models/smollm3/smollm3_causal_lm.py create mode 100644 keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py create mode 100644 keras_hub/src/models/smollm3/smollm3_tokenizer.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 9ad3aeb204..4ddcbc4521 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -576,6 +576,30 @@ from keras_hub.src.models.siglip.siglip_vision_encoder import ( SigLIPVisionEncoder as SigLIPVisionEncoder, ) +from keras_hub.src.models.smollm3.smollm3_backbone import ( + SmolLM3Backbone as SmolLM3Backbone, +) +from keras_hub.src.models.smollm3.smollm3_backbone import ( + SmolLM3Backbone as SmolLMBackbone, +) +from keras_hub.src.models.smollm3.smollm3_causal_lm import ( + SmolLM3CausalLM as SmolLM3CausalLM, +) +from keras_hub.src.models.smollm3.smollm3_causal_lm import ( + SmolLM3CausalLM as SmolLMCausalLM, +) +from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import ( + SmolLM3CausalLMPreprocessor as SmolLM3CausalLMPreprocessor, +) +from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import ( + SmolLM3CausalLMPreprocessor as 
SmolLMCausalLMPreprocessor, +) +from keras_hub.src.models.smollm3.smollm3_tokenizer import ( + SmolLM3Tokenizer as SmolLM3Tokenizer, +) +from keras_hub.src.models.smollm3.smollm3_tokenizer import ( + SmolLM3Tokenizer as SmolLMTokenizer, +) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( StableDiffusion3Backbone as StableDiffusion3Backbone, ) diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 082078184f..49b4eeab99 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -86,6 +86,12 @@ from keras_hub.src.models.siglip.siglip_tokenizer import ( SigLIPTokenizer as SigLIPTokenizer, ) +from keras_hub.src.models.smollm3.smollm3_tokenizer import ( + SmolLM3Tokenizer as SmolLM3Tokenizer, +) +from keras_hub.src.models.smollm3.smollm3_tokenizer import ( + SmolLM3Tokenizer as SmolLMTokenizer, +) from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer from keras_hub.src.models.whisper.whisper_tokenizer import ( WhisperTokenizer as WhisperTokenizer, diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 5b8848da22..11564b68d7 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -1,6 +1,9 @@ import keras from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding @@ -68,12 +71,12 @@ def __init__( **kwargs, ): # === Layers === - self.token_embedding = keras.layers.Embedding( + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, name="token_embedding", ) - self.decoder_layers = [] + self.transformer_layers = [] for i in range(num_layers): layer = SmolLM3DecoderLayer( @@ -87,10 +90,10 @@ def __init__( layer_idx=i, intermediate_size=intermediate_dim, mlp_bias=mlp_bias, - rms_norm_epsilon=layer_norm_epsilon, + layer_norm_epsilon=layer_norm_epsilon, name=f"transformer_layer_{i}", ) - self.decoder_layers.append(layer) + self.transformer_layers.append(layer) self.norm = keras.layers.RMSNormalization( epsilon=layer_norm_epsilon, @@ -112,16 +115,20 @@ def __init__( position_id_input = keras.Input( shape=(None,), dtype="int32", name="position_ids" ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) hidden_states = self.token_embedding(token_id_input) position_embeddings = self.rotary_embedding( hidden_states, position_id_input ) - for decoder_layer in self.decoder_layers[:num_layers]: + for decoder_layer in self.transformer_layers[:num_layers]: hidden_states = decoder_layer( hidden_states, position_embeddings=position_embeddings, + decoder_padding_mask=padding_mask_input, **kwargs, ) @@ -130,6 +137,7 @@ def __init__( inputs={ "token_ids": token_id_input, "position_ids": position_id_input, + "padding_mask": padding_mask_input, }, outputs=sequence_output, **kwargs, @@ -137,14 +145,40 @@ def __init__( # === Config === self.vocabulary_size = vocabulary_size + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.attention_bias = 
attention_bias + self.attention_dropout = attention_dropout + self.rope_layer_enabled_list = rope_layer_enabled_list + self.layer_types = layer_types + self.mlp_bias = mlp_bias + self.layer_norm_epsilon = layer_norm_epsilon + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.partial_rotary_factor = partial_rotary_factor def get_config(self): config = super().get_config() config.update( { "vocabulary_size": self.vocabulary_size, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "rope_layer_enabled_list": self.rope_layer_enabled_list, + "layer_types": self.layer_types, + "mlp_bias": self.mlp_bias, + "layer_norm_epsilon": self.layer_norm_epsilon, + "max_position_embeddings": self.max_position_embeddings, + "rope_theta": self.rope_theta, + "partial_rotary_factor": self.partial_rotary_factor, } ) return config diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py new file mode 100644 index 0000000000..965421d86f --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -0,0 +1,325 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone +from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import ( + SmolLM3CausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export( + [ + "keras_hub.models.SmolLM3CausalLM", + "keras_hub.models.SmolLMCausalLM", + ] +) +class SmolLM3CausalLM(CausalLM): + backbone_cls = SmolLM3Backbone + preprocessor_cls = SmolLM3CausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. + inputs = backbone.input + hidden_states = backbone(inputs) + # Use self.backbone for clarity and consistency + outputs = self.backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def call_with_cache( + self, + token_ids, + position_ids, + cache, + cache_update_index, + ): + """Forward pass of `SmolLM3CausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, 1)`. + (For generation, this is typically a single new token.) + cache: a dense float Tensor, the cache of key and value. + Shape: (batch_size, num_layers, 2, max_seq_len, num_key_value_heads, head_dim) + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + training: Boolean, whether the call is during training or inference. + attention_mask: Optional attention mask. + + Returns: + A (logits, hidden_states, cache) tuple. 
Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + # Each decoder layer has a cache; we update them separately. + position_embeddings = self.backbone.rotary_embedding(x, position_ids) + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] + x, next_cache = self.backbone.transformer_layers[i]( + x, + position_embeddings=position_embeddings, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids, position_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads + shape = [ + batch_size, + num_layers, + 2, + num_key_value_heads, + max_length, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + index = ops.convert_to_tensor(0, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache( + token_ids, position_ids, cache, index + ) + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + stop_token_ids: Tuple of id's of the end token to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + position_ids = ops.arange(token_ids.shape[0]) + + hidden_states, cache = self._build_cache(token_ids, position_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + position_ids, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original + # prompt (not in locations where `padding_mask` is True). 
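+            # Worked example for a single sequence (illustrative values):
+            #   end_locations = [0, 0, 1, 0, 0]
+            #   cumsum        = [0, 0, 1, 1, 1]
+            #   overflow      = [0, 0, 0, 1, 1]
+            #   padding_mask  = [1, 1, 1, 0, 0]
+            # i.e. everything up to and including the first stop token is
+            # kept, and everything after it is masked out.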
+ end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + """Score a generation represented by the provided token ids. + + Args: + token_ids: A [batch_size, num_tokens] tensor containing tokens + to score. Typically, this tensor captures the output from a call + to `SmolLM3CausalLM.generate()`, i.e., tokens for both the input + text and the model-generated text. + padding_mask: A [batch_size, num_tokens] tensor indicating the + tokens that should be preserved during generation. This is an + artifact required by the `SmolLM3Backbone` and isn't influential + on the computation of this function. If omitted, this function + uses `keras.ops.ones()` to create a tensor of the appropriate + shape. + scoring_mode: The type of scores to return, either "logits" or + "loss", both will be per input token. + layer_intercept_fn: An optional function for augmenting activations + with additional computation, for example, as part of + interpretability research. This function will be passed the + activations as its first parameter and a numeric index + associated with that backbone layer. _This index _is not_ an + index into `self.backbone.layers`_. The index -1 accompanies the + embeddings returned by calling `self.backbone.token_embedding()` + on `token_ids` in the forward direction. All subsequent indexes + will be 0-based indices for the activations returned by each of + the Transformers layers in the backbone. This function must + return a [batch_size, num_tokens, hidden_dims] tensor + that can be passed as an input to the next layer in the model. + target_ids: An [batch_size, num_tokens] tensor containing the + predicted tokens against which the loss should be computed. If a + span of tokens is provided (sequential truthy values along + axis=1 in the tensor), the loss will be computed as the + aggregate across those tokens. + + Raises: + ValueError: If an unsupported scoring_mode is provided, or if the + target_ids are not provided when using ScoringMode.LOSS. + + Returns: + The per-token scores as a tensor of size + [batch_size, num_tokens, vocab_size] in "logits" mode, or + [batch_size, num_tokens] in "loss" mode. 
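+            In "loss" mode the values are the per-token sparse categorical
+            crossentropy between the computed logits and `target_ids`, with
+            no reduction applied.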
+ + Example: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + smol_lm = keras_hub.models.SmolLM3CausalLM.from_preset("...") + generations = smol_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = smol_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = smol_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." + ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + # Generate position_ids for the full sequence + seq_len = ops.shape(token_ids)[1] + position_ids = ops.arange(0, seq_len, dtype="int32")[None, :] + position_ids = ops.broadcast_to(position_ids, (batch_shape[0], seq_len)) + + # Get position embeddings for the full sequence + position_embeddings = self.backbone.rotary_embedding(x, position_ids) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer( + hidden_states=x, + position_embeddings=position_embeddings, + attention_mask=padding_mask, + ) + x = layer_intercept_fn(x, i) + + x = self.backbone.norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py b/keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py new file mode 100644 index 0000000000..432519829f --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py @@ -0,0 +1,84 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone +from keras_hub.src.models.smollm3.smollm3_tokenizer import SmolLM3Tokenizer + + +@keras_hub_export( + [ + "keras_hub.models.SmolLMCausalLMPreprocessor", + "keras_hub.models.SmolLM3CausalLMPreprocessor", + ] +) +class SmolLM3CausalLMPreprocessor(CausalLMPreprocessor): + """SmolLM3 Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.SmolLM3CausalLM`. 
By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_hub.models.SmolLM3CausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_hub.models.SmolLM3Tokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Default is `False`. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_hub.models.SmolLM3CausalLMPreprocessor.from_preset( + "..." + ) + + # Tokenize and pack a single sentence. + sentence = tf.constant("...") + preprocessor(sentence) + # Same output. + preprocessor("...") + + # Tokenize a batch of sentences. + sentences = tf.constant(["...", "..."]) + preprocessor(sentences) + # Same output. + preprocessor(["...", "..."]) + + # Map a dataset to preprocess a single sentence. + features = tf.constant( + [ + "...", + "...", + ] + ) + labels = tf.constant([1, 0]) + ds = tf.data.Dataset.from_tensor_slices((features, labels)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map a dataset to preprocess unlabled sentences. + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + backbone_cls = SmolLM3Backbone + tokenizer_cls = SmolLM3Tokenizer + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 2c60174324..d51b6f7796 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -6,6 +6,9 @@ from keras_hub.src.layers.modeling.transformer_layer_utils import ( compute_causal_mask, ) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward from keras_hub.src.models.smollm3.smollm3_utils import rope_init @@ -103,8 +106,8 @@ def call( self, hidden_states, position_embeddings, - attention_mask, training=False, + attention_mask=None, **kwargs, ): """ @@ -117,34 +120,62 @@ def call( training: Whether the layer is in training mode. 
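+            self_attention_cache: Optional per-layer key/value cache, passed
+                through `kwargs`. The key and value tensors are stacked along
+                axis 1 of this tensor.
+            self_attention_cache_update_index: Optional index at which new
+                key/value entries are written when decoding with a cache.
+
+        Returns:
+            The attention output, or an `(attn_output, self_attention_cache)`
+            tuple when a self-attention cache is passed in.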
""" self.training = training + self_attention_cache = kwargs.get("self_attention_cache", None) + self_attention_cache_update_index = kwargs.get( + "self_attention_cache_update_index", None + ) - input_shape = ops.shape(hidden_states)[ - :-1 - ] # Exclude last dim (hidden_size) + input_shape = ops.shape(hidden_states)[:-1] + hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim) - query_states = ops.reshape( - self.q_proj(hidden_states), - (*input_shape, self.num_attention_heads, self.head_dim), - ) - query_states = ops.transpose( - query_states, axes=(0, 2, 1, 3) - ) # (batch, num_heads, seq_len, head_dim) - - # For key and value, the kv_hidden_shape should be based on num_key_value_heads - kv_hidden_shape = ( - *input_shape, - self.num_key_value_heads, - self.head_dim, - ) - key_states = ops.reshape(self.k_proj(hidden_states), kv_hidden_shape) - key_states = ops.transpose( - key_states, axes=(0, 2, 1, 3) - ) # (batch, num_key_value_heads, seq_len, head_dim) + query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape) + # (batch, num_heads, seq_len, head_dim) + query_states = ops.transpose(query_states, axes=(0, 2, 1, 3)) - value_states = ops.reshape(self.v_proj(hidden_states), kv_hidden_shape) - value_states = ops.transpose( - value_states, axes=(0, 2, 1, 3) - ) # (batch, num_key_value_heads, seq_len, head_dim) + def _compute_kv_values(x_input): + kv_hidden_shape = ( + *input_shape, + self.num_key_value_heads, + self.head_dim, + ) + + key_states_raw = ops.reshape(self.k_proj(x_input), kv_hidden_shape) + value_states_raw = ops.reshape( + self.v_proj(x_input), kv_hidden_shape + ) + + key_states = ops.transpose(key_states_raw, axes=(0, 2, 1, 3)) + value_states = ops.transpose(value_states_raw, axes=(0, 2, 1, 3)) + return key_states, value_states + + if self_attention_cache is not None: + key_cache = self_attention_cache[:, 0, ...] + value_cache = self_attention_cache[:, 1, ...] + + if self_attention_cache_update_index is None: + key_states = key_cache + value_states = value_cache + else: + key_update, value_update = _compute_kv_values(hidden_states) + update_idx_tensor = ops.convert_to_tensor( + self_attention_cache_update_index, dtype="int32" + ) + start = [0, 0, update_idx_tensor, 0] + key_states = ops.slice_update(key_cache, start, key_update) + value_states = ops.slice_update( + value_cache, start, value_update + ) + self_attention_cache = ops.stack( + (key_states, value_states), axis=1 + ) + else: + if self_attention_cache_update_index is not None: + raise ValueError( + "`self_attention_cache_update_index` should not be set if `self_attention_cache` is " + f"`None`. 
Received: self_attention_cache={self_attention_cache}, " + f"self_attention_cache_update_index={self_attention_cache_update_index}" + ) + key_states, value_states = _compute_kv_values(hidden_states) if self.use_rope: cos, sin = position_embeddings @@ -152,23 +183,25 @@ def call( query_states, key_states, cos, sin ) - attn_output, attn_weights = eager_attention_forward( + attn_output = eager_attention_forward( module=self, query=query_states, key=key_states, value=value_states, - attention_mask=attention_mask, dropout=self.attention_dropout, scaling=self.scaling, training=self.training, - **kwargs, + attention_mask=attention_mask, ) attn_output = ops.reshape(attn_output, (*input_shape, self.hidden_size)) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + if self_attention_cache is not None: + return attn_output, self_attention_cache + + return attn_output def compute_output_shape(self, input_shape): """ @@ -290,7 +323,7 @@ class SmolLM3DecoderLayer(layers.Layer): layer_idx: Index of the current layer. intermediate_size: The intermediate size of the MLP. mlp_bias: Whether to use bias in MLP dense layers. - rms_norm_epsilon: Epsilon for RMSNormalization. + layer_norm_epsilon: Epsilon for RMSNormalization. """ def __init__( @@ -305,7 +338,7 @@ def __init__( layer_idx: int, intermediate_size: int, mlp_bias: bool, - rms_norm_epsilon: float, + layer_norm_epsilon: float, **kwargs, ): super().__init__(**kwargs) @@ -332,14 +365,49 @@ def __init__( ) self.input_layernorm = layers.RMSNormalization( - epsilon=rms_norm_epsilon, axis=-1, name="input_layernorm" + epsilon=layer_norm_epsilon, axis=-1, name="input_layernorm" ) self.post_attention_layernorm = layers.RMSNormalization( - epsilon=rms_norm_epsilon, axis=-1, name="post_attention_layernorm" + epsilon=layer_norm_epsilon, axis=-1, name="post_attention_layernorm" ) self.attention_type = layer_types[layer_idx] + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + self_attention_cache, + self_attention_cache_update_index, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if self_attention_cache is not None: + input_length = ops.shape(self_attention_cache)[2] + + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + def build(self, input_shape): """ Builds the sub-layers based on the input shape. @@ -375,6 +443,8 @@ def call( hidden_states, position_embeddings=None, training=False, + decoder_padding_mask=None, + decoder_attention_mask=None, **kwargs, ): """ @@ -385,23 +455,36 @@ def call( position_embeddings: Optional tuple of (cos, sin) tensors for RoPE. training: Whether the layer is in training mode. 
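+            decoder_padding_mask: Optional padding mask for the decoder
+                sequence, merged with the causal mask.
+            decoder_attention_mask: Optional user-provided attention mask,
+                also merged with the causal mask.
+            self_attention_cache: Optional per-layer key/value cache, passed
+                through `kwargs` and forwarded to the attention layer.
+            self_attention_cache_update_index: Optional index of the current
+                tokens within the full sequence when decoding with a cache.
+
+        Returns:
+            The transformed hidden states, or a `(hidden_states, cache)` tuple
+            when a self-attention cache is passed in.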
""" - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + self_attention_cache = kwargs.get("self_attention_cache", None) + self_attention_cache_update_index = kwargs.get( + "self_attention_cache_update_index", None + ) - attention_mask = compute_causal_mask( - ops.shape(hidden_states)[0], - ops.shape(hidden_states)[1], - ops.shape(hidden_states)[1], + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=hidden_states, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=self_attention_cache_update_index, ) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention - attn_output, _ = self.self_attn( + x = self.self_attn( hidden_states=hidden_states, - attention_mask=attention_mask, position_embeddings=position_embeddings, training=training, + attention_mask=self_attention_mask, **kwargs, ) + + if isinstance(x, tuple): + attn_output, self_attention_cache = x + else: + attn_output = x + hidden_states = ops.add(residual, attn_output) residual = hidden_states @@ -409,7 +492,10 @@ def call( hidden_states = self.mlp(hidden_states) hidden_states = ops.add(residual, hidden_states) - return hidden_states + if self_attention_cache is not None: + return hidden_states, self_attention_cache + else: + return hidden_states def compute_output_shape(self, input_shape): """ diff --git a/keras_hub/src/models/smollm3/smollm3_tokenizer.py b/keras_hub/src/models/smollm3/smollm3_tokenizer.py new file mode 100644 index 0000000000..c1df7c5eb4 --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_tokenizer.py @@ -0,0 +1,60 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export( + [ + "keras_hub.tokenizers.SmolLM3Tokenizer", + "keras_hub.tokenizers.SmolLMTokenizer", + "keras_hub.models.SmolLM3Tokenizer", + "keras_hub.models.SmolLMTokenizer", + ] +) +class SmolLM3Tokenizer(BytePairTokenizer): + """Tokenizer for SmolLM3 models. + + This tokenizer implements byte-pair encoding (BPE) for SmolLM3 models, + handling special tokens like BOS (beginning of sequence) and EOS (end of + sequence). + + Args: + vocabulary: Dictionary mapping tokens to token IDs, or path to + vocabulary file. + merges: List of BPE merges, or path to merges file. + bos_token: Beginning of sequence token. Defaults to None. + eos_token: End of sequence token. Defaults to "<|endoftext|>". + misc_special_tokens: Set of additional special tokens. Defaults to + empty set. 
+ """ + + backbone_cls = SmolLM3Backbone + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + # Add EOS token + eos_token = "<|end_of_text|>" + self._add_special_token(eos_token, "end_token") + + bos_token = "<|begin_of_text|>" + self._add_special_token(bos_token, "bos_token") + + start_think_token = "" + self._add_special_token(start_think_token, "start_think_token") + + end_think_token = "" + self._add_special_token(end_think_token, "end_think_token") + + self.start_token_id = None + self.start_token = None + self.pad_token_id = 0 + + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index bb852a6915..4861e5a38b 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -3,8 +3,8 @@ def rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] + x1 = x[..., : ops.shape(x)[-1] // 2] + x2 = x[..., ops.shape(x)[-1] // 2 :] return ops.concatenate((-x2, x1), axis=-1) @@ -33,8 +33,8 @@ def eager_attention_forward( query, key, value, - attention_mask, scaling, + attention_mask=None, dropout=0.0, training=False, ): @@ -46,17 +46,18 @@ def eager_attention_forward( * scaling ) - # Apply attention mask if provided if attention_mask is not None: - attn_weights = ops.add(attn_weights, attention_mask) + causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]] + attn_weights = ops.add(attn_weights, causal_mask) attn_weights = ops.softmax(attn_weights, axis=-1) - if not training: + + if training: attn_weights = random.dropout(attn_weights, rate=dropout) attn_output = ops.matmul(attn_weights, value_states) attn_output = ops.transpose(attn_output, axes=(0, 2, 1, 3)) - return attn_output, attn_weights + return attn_output def rope_init(rope_theta: float, partial_rotary_factor: float, head_dim: int): diff --git a/keras_hub/src/utils/transformers/convert_smollm3.py b/keras_hub/src/utils/transformers/convert_smollm3.py index d93e0c1361..46ff0cc2ce 100644 --- a/keras_hub/src/utils/transformers/convert_smollm3.py +++ b/keras_hub/src/utils/transformers/convert_smollm3.py @@ -52,7 +52,6 @@ def transpose_and_reshape(x, shape): ) # Attention layers - ## Query loader.port_weight( keras_variable=decoder_layer.self_attn.q_proj.kernel, @@ -75,8 +74,6 @@ def transpose_and_reshape(x, shape): loader.port_weight( keras_variable=decoder_layer.self_attn.o_proj.kernel, hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", - # rearrange_patterns="c (a b) -> a b c", - # rearrange_dims={"a": backbone.num_query_heads}, hook_fn=transpose_and_reshape, ) @@ -112,6 +109,8 @@ def transpose_and_reshape(x, shape): hf_weight_key="model.norm.weight", ) + backbone.training = False + return backbone From d81e83156c4408eb479d004437daa9564561c7bf Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 17:04:57 +0900 Subject: [PATCH 18/76] add softmax op --- keras_hub/src/models/smollm3/smollm3_utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 4861e5a38b..2bcd8d64f7 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -1,3 +1,4 @@ +from keras import layers from keras import ops from keras import random @@ -38,6 +39,12 @@ def eager_attention_forward( dropout=0.0, training=False, ): + 
softmax_op = layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -47,10 +54,9 @@ def eager_attention_forward( ) if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]] - attn_weights = ops.add(attn_weights, causal_mask) - - attn_weights = ops.softmax(attn_weights, axis=-1) + attn_weights = softmax_op(attn_weights, attention_mask[:, None, :, :]) + else: + attn_weights = softmax_op(attn_weights) if training: attn_weights = random.dropout(attn_weights, rate=dropout) From e07e8487f3512dca3cf2e2e8ad51345b65d6a8eb Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 19:45:05 +0900 Subject: [PATCH 19/76] fix build cache shape? --- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 2 +- keras_hub/src/models/smollm3/smollm3_layers.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index 965421d86f..d28534ef5d 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -97,8 +97,8 @@ def _build_cache(self, token_ids, position_ids): batch_size, num_layers, 2, - num_key_value_heads, max_length, + num_key_value_heads, head_dim, ] cache = ops.zeros(shape, dtype=self.compute_dtype) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index d51b6f7796..5aad49cfb5 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -381,6 +381,20 @@ def _compute_self_attention_mask( self_attention_cache, self_attention_cache_update_index, ): + """Computes the self-attention mask combining causal, padding and + attention masks. + + Args: + decoder_sequence: Input tensor. + decoder_padding_mask: Mask tensor for padding tokens. + decoder_attention_mask: Additional attention mask. + self_attention_cache: Optional cached key and value tensors. + self_attention_cache_update_index: Index at which to update the + cache. + + Returns: + Combined attention mask tensor. 
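+            Its shape broadcasts against `(batch_size, output_length,
+            input_length)` attention scores; a heads axis is added later,
+            inside the attention computation.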
+ """ decoder_mask = merge_padding_and_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask ) From e25fcdd83e285a49915cba1be57cda528136b3a4 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 19:53:43 +0900 Subject: [PATCH 20/76] fix shape positioning in cache update --- keras_hub/src/models/smollm3/smollm3_layers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 5aad49cfb5..e1e84cd91c 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -160,11 +160,9 @@ def _compute_kv_values(x_input): update_idx_tensor = ops.convert_to_tensor( self_attention_cache_update_index, dtype="int32" ) - start = [0, 0, update_idx_tensor, 0] + start = [0, update_idx_tensor, 0, 0] key_states = ops.slice_update(key_cache, start, key_update) - value_states = ops.slice_update( - value_cache, start, value_update - ) + value_states = ops.slice_update(value_cache, start, value_update) self_attention_cache = ops.stack( (key_states, value_states), axis=1 ) From 5a49ed6093cf34464e3a1f4b09671fe616351299 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 20:11:10 +0900 Subject: [PATCH 21/76] Remove position ids as input --- .../src/models/smollm3/smollm3_backbone.py | 15 +++++---- .../src/models/smollm3/smollm3_causal_lm.py | 32 +++++++++++++------ .../src/models/smollm3/smollm3_layers.py | 5 +-- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 11564b68d7..2d246b5244 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -7,7 +7,7 @@ from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding - +from keras import ops @keras_hub_export( [ @@ -112,16 +112,20 @@ def __init__( token_id_input = keras.Input( shape=(None,), dtype="int32", name="token_ids" ) - position_id_input = keras.Input( - shape=(None,), dtype="int32", name="position_ids" - ) + padding_mask_input = keras.Input( shape=(None,), dtype="int32", name="padding_mask" ) + # Infer position IDs from the shape of token IDs. + seq_len = ops.shape(token_id_input)[1] + position_ids = ops.arange(0, seq_len, dtype="int32") + # Add a batch dimension to broadcast. 
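+        # The resulting `(1, seq_len)` tensor broadcasts across the batch
+        # inside the rotary embedding below.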
+ position_ids = ops.expand_dims(position_ids, axis=0) + hidden_states = self.token_embedding(token_id_input) position_embeddings = self.rotary_embedding( - hidden_states, position_id_input + hidden_states, position_ids ) for decoder_layer in self.transformer_layers[:num_layers]: @@ -136,7 +140,6 @@ def __init__( super().__init__( inputs={ "token_ids": token_id_input, - "position_ids": position_id_input, "padding_mask": padding_mask_input, }, outputs=sequence_output, diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index d28534ef5d..d8385a13d4 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -41,7 +41,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs): def call_with_cache( self, token_ids, - position_ids, cache, cache_update_index, ): @@ -53,8 +52,9 @@ def call_with_cache( and avoids recomputing the outputs of seen tokens. Args: - token_ids: a dense int Tensor with shape `(batch_size, 1)`. - (For generation, this is typically a single new token.) + token_ids: a dense int Tensor with shape `(batch_size, seq_len)`. + For prefill, `seq_len` is the prompt length. For generation, + `seq_len` is typically 1. cache: a dense float Tensor, the cache of key and value. Shape: (batch_size, num_layers, 2, max_seq_len, num_key_value_heads, head_dim) cache_update_index: int, or int Tensor. The index of current inputs @@ -69,6 +69,20 @@ def call_with_cache( the decoding cache. """ x = self.backbone.token_embedding(token_ids) + + # Infer position_ids based on the input shape. + seq_len = ops.shape(token_ids)[1] + if seq_len > 1: + # Prefill stage for the initial prompt. + position_ids = ops.arange(0, seq_len, dtype="int32") + position_ids = ops.expand_dims(position_ids, axis=0) + else: + # Decoding stage for a single token. + batch_size = ops.shape(token_ids)[0] + position_ids = ops.full( + (batch_size, 1), cache_update_index, dtype="int32" + ) + # Each decoder layer has a cache; we update them separately. position_embeddings = self.backbone.rotary_embedding(x, position_ids) updated_cache = [] @@ -86,7 +100,7 @@ def call_with_cache( logits = self.backbone.token_embedding(x, reverse=True) return logits, hidden_states, cache - def _build_cache(self, token_ids, position_ids): + def _build_cache(self, token_ids): """Build an empty cache for use with `call_with_cache()`.""" batch_size = ops.shape(token_ids)[0] max_length = ops.shape(token_ids)[1] @@ -102,10 +116,10 @@ def _build_cache(self, token_ids, position_ids): head_dim, ] cache = ops.zeros(shape, dtype=self.compute_dtype) - index = ops.convert_to_tensor(0, dtype=self.compute_dtype) + index = ops.convert_to_tensor(0, dtype="int32") # Seed the cache. _, hidden_states, cache = self.call_with_cache( - token_ids, position_ids, cache, index + token_ids, cache, index ) return hidden_states, cache @@ -128,9 +142,8 @@ def generate_step( will stop. """ token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] - position_ids = ops.arange(token_ids.shape[0]) - hidden_states, cache = self._build_cache(token_ids, position_ids) + hidden_states, cache = self._build_cache(token_ids) # Compute the lengths of all user inputted tokens ids. row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) # Start at the first index that has no user inputted id. 
@@ -143,7 +156,6 @@ def next(prompt, cache, index): prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, - position_ids, cache, cache_update_index, ) @@ -322,4 +334,4 @@ def default_layer_intercept_fn(x, unused_i): from_logits=True, reduction="none" ) per_token_loss = per_token_loss_fn(target_ids, logits) - return per_token_loss + return per_token_loss \ No newline at end of file diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index e1e84cd91c..ecb80a0f3b 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -157,10 +157,7 @@ def _compute_kv_values(x_input): value_states = value_cache else: key_update, value_update = _compute_kv_values(hidden_states) - update_idx_tensor = ops.convert_to_tensor( - self_attention_cache_update_index, dtype="int32" - ) - start = [0, update_idx_tensor, 0, 0] + start = [0, self_attention_cache_update_index, 0, 0] key_states = ops.slice_update(key_cache, start, key_update) value_states = ops.slice_update(value_cache, start, value_update) self_attention_cache = ops.stack( From 89391d9335a12ae5727de34f408922e4d58adcb9 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 20:20:46 +0900 Subject: [PATCH 22/76] use sampler's max length --- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index d8385a13d4..13bd523b15 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -103,7 +103,8 @@ def call_with_cache( def _build_cache(self, token_ids): """Build an empty cache for use with `call_with_cache()`.""" batch_size = ops.shape(token_ids)[0] - max_length = ops.shape(token_ids)[1] + #max_length = ops.shape(token_ids)[1] + max_length = self.sampler.max_length num_layers = self.backbone.num_layers num_key_value_heads = self.backbone.num_key_value_heads head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads From 7a9d99cc872aef02a0e92dfc4ab77d4b26a6efd3 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 20:36:37 +0900 Subject: [PATCH 23/76] format --- keras_hub/src/models/smollm3/smollm3_backbone.py | 7 +++---- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 9 +++------ keras_hub/src/models/smollm3/smollm3_layers.py | 4 +++- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 2d246b5244..4d6401707f 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -1,4 +1,5 @@ import keras +from keras import ops from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.modeling.reversible_embedding import ( @@ -7,7 +8,7 @@ from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding -from keras import ops + @keras_hub_export( [ @@ -124,9 +125,7 @@ def __init__( position_ids = ops.expand_dims(position_ids, axis=0) hidden_states = self.token_embedding(token_id_input) - position_embeddings = self.rotary_embedding( - hidden_states, position_ids - ) + position_embeddings 
= self.rotary_embedding(hidden_states, position_ids) for decoder_layer in self.transformer_layers[:num_layers]: hidden_states = decoder_layer( diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index 13bd523b15..e1f324ec56 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -103,8 +103,7 @@ def call_with_cache( def _build_cache(self, token_ids): """Build an empty cache for use with `call_with_cache()`.""" batch_size = ops.shape(token_ids)[0] - #max_length = ops.shape(token_ids)[1] - max_length = self.sampler.max_length + max_length = ops.shape(token_ids)[1] num_layers = self.backbone.num_layers num_key_value_heads = self.backbone.num_key_value_heads head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads @@ -119,9 +118,7 @@ def _build_cache(self, token_ids): cache = ops.zeros(shape, dtype=self.compute_dtype) index = ops.convert_to_tensor(0, dtype="int32") # Seed the cache. - _, hidden_states, cache = self.call_with_cache( - token_ids, cache, index - ) + _, hidden_states, cache = self.call_with_cache(token_ids, cache, index) return hidden_states, cache def generate_step( @@ -335,4 +332,4 @@ def default_layer_intercept_fn(x, unused_i): from_logits=True, reduction="none" ) per_token_loss = per_token_loss_fn(target_ids, logits) - return per_token_loss \ No newline at end of file + return per_token_loss diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index ecb80a0f3b..b784a5db29 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -159,7 +159,9 @@ def _compute_kv_values(x_input): key_update, value_update = _compute_kv_values(hidden_states) start = [0, self_attention_cache_update_index, 0, 0] key_states = ops.slice_update(key_cache, start, key_update) - value_states = ops.slice_update(value_cache, start, value_update) + value_states = ops.slice_update( + value_cache, start, value_update + ) self_attention_cache = ops.stack( (key_states, value_states), axis=1 ) From e3067a5a6ad8989e3d5196438d21bf6bcb9ba800 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 21:20:49 +0900 Subject: [PATCH 24/76] add logs --- keras_hub/src/models/smollm3/smollm3_layers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index b784a5db29..cfc953a994 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -151,13 +151,18 @@ def _compute_kv_values(x_input): if self_attention_cache is not None: key_cache = self_attention_cache[:, 0, ...] value_cache = self_attention_cache[:, 1, ...] 
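+            # `key_cache` and `value_cache` are the key and value slices of
+            # the per-layer cache, taken along axis 1 of
+            # `self_attention_cache`.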
+ print("key_cache", key_cache.shape) + print("value_cache", value_cache.shape) if self_attention_cache_update_index is None: key_states = key_cache value_states = value_cache else: key_update, value_update = _compute_kv_values(hidden_states) + print("key_update", key_update.shape) + print("value_update", value_update.shape) start = [0, self_attention_cache_update_index, 0, 0] + print("start", start) key_states = ops.slice_update(key_cache, start, key_update) value_states = ops.slice_update( value_cache, start, value_update From 76223151eff8784712d77713eacf577d864c9b0f Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 21:24:14 +0900 Subject: [PATCH 25/76] switch order or value heads and max length --- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index e1f324ec56..e479f51124 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -111,8 +111,8 @@ def _build_cache(self, token_ids): batch_size, num_layers, 2, - max_length, num_key_value_heads, + max_length, head_dim, ] cache = ops.zeros(shape, dtype=self.compute_dtype) From 982a5465ba3e9747614d0839939aab2c4aeef8dd Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 21:33:46 +0900 Subject: [PATCH 26/76] oh god please --- keras_hub/src/models/smollm3/smollm3_layers.py | 13 ++++++++----- keras_hub/src/models/smollm3/smollm3_utils.py | 1 + 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index cfc953a994..c8f0bea724 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -418,11 +418,14 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, cache_update_index ) - return ( - ops.minimum(decoder_mask, causal_mask) - if decoder_mask is not None - else causal_mask - ) + if decoder_mask is not None: + # Expand decoder mask from [batch, tgt_len] to [batch, tgt_len, input_len] + # This is done by broadcasting + decoder_mask = ops.expand_dims(decoder_mask, axis=-1) + decoder_mask = ops.broadcast_to(decoder_mask, ops.shape(causal_mask)) + return ops.minimum(decoder_mask, causal_mask) + + return causal_mask def build(self, input_shape): """ diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 2bcd8d64f7..95f72861dc 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -54,6 +54,7 @@ def eager_attention_forward( ) if attention_mask is not None: + attention_mask = attention_mask.expand(-1, attn_weights.shape[1], -1, -1) attn_weights = softmax_op(attn_weights, attention_mask[:, None, :, :]) else: attn_weights = softmax_op(attn_weights) From 7319f48e0415a47dc637f8956b7ee6193447d448 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 21:37:12 +0900 Subject: [PATCH 27/76] oh god please --- keras_hub/src/models/smollm3/smollm3_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 95f72861dc..aae2dd39e5 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -53,8 +53,10 @@ def 
eager_attention_forward( * scaling ) + print("attn_weights", attn_weights.shape) + print("attention_mask", attention_mask.shape) + if attention_mask is not None: - attention_mask = attention_mask.expand(-1, attn_weights.shape[1], -1, -1) attn_weights = softmax_op(attn_weights, attention_mask[:, None, :, :]) else: attn_weights = softmax_op(attn_weights) From 3c3d7fbd11e0ecaa8381fb8a943f619dc3f20431 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 21:40:51 +0900 Subject: [PATCH 28/76] oh god please --- keras_hub/src/models/smollm3/smollm3_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index c8f0bea724..e74378bc4a 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -406,7 +406,7 @@ def _compute_self_attention_mask( # decoding. For generative inference, `decoder_sequence` will # generally be length 1, and `cache` will be the full generation length. if self_attention_cache is not None: - input_length = ops.shape(self_attention_cache)[2] + input_length = ops.shape(self_attention_cache)[3] cache_update_index = ( 0 From 8046d4ba97d47aafec66686895c99818cf39cef4 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 21:45:47 +0900 Subject: [PATCH 29/76] oh god please --- .../src/models/smollm3/smollm3_layers.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index e74378bc4a..61e087d05c 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -374,7 +374,7 @@ def __init__( ) self.attention_type = layer_types[layer_idx] - + def _compute_self_attention_mask( self, decoder_sequence, @@ -401,16 +401,15 @@ def _compute_self_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask ) batch_size = ops.shape(decoder_sequence)[0] - input_length = output_length = ops.shape(decoder_sequence)[1] - # We need to handle a rectangular causal mask when doing cached - # decoding. For generative inference, `decoder_sequence` will - # generally be length 1, and `cache` will be the full generation length. 
+ output_length = ops.shape(decoder_sequence)[1] + input_length = output_length # Default if no cache is present + if self_attention_cache is not None: + # shape: [batch, 2, num_heads, key_len, head_dim] input_length = ops.shape(self_attention_cache)[3] cache_update_index = ( - 0 - if self_attention_cache_update_index is None + 0 if self_attention_cache_update_index is None else self_attention_cache_update_index ) @@ -418,14 +417,13 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, cache_update_index ) - if decoder_mask is not None: - # Expand decoder mask from [batch, tgt_len] to [batch, tgt_len, input_len] - # This is done by broadcasting - decoder_mask = ops.expand_dims(decoder_mask, axis=-1) - decoder_mask = ops.broadcast_to(decoder_mask, ops.shape(causal_mask)) - return ops.minimum(decoder_mask, causal_mask) + # Combine causal and user-provided masks + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) - return causal_mask def build(self, input_shape): """ From 2d4a3b567399321fe8f810c822733d83e91f4f4a Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 21:54:30 +0900 Subject: [PATCH 30/76] oh god please --- keras_hub/src/models/smollm3/smollm3_layers.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 61e087d05c..b291d79e51 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -161,7 +161,7 @@ def _compute_kv_values(x_input): key_update, value_update = _compute_kv_values(hidden_states) print("key_update", key_update.shape) print("value_update", value_update.shape) - start = [0, self_attention_cache_update_index, 0, 0] + start = [0, 0, self_attention_cache_update_index, 0] print("start", start) key_states = ops.slice_update(key_cache, start, key_update) value_states = ops.slice_update( @@ -374,7 +374,7 @@ def __init__( ) self.attention_type = layer_types[layer_idx] - + def _compute_self_attention_mask( self, decoder_sequence, @@ -402,10 +402,10 @@ def _compute_self_attention_mask( ) batch_size = ops.shape(decoder_sequence)[0] output_length = ops.shape(decoder_sequence)[1] - input_length = output_length # Default if no cache is present + input_length = output_length if self_attention_cache is not None: - # shape: [batch, 2, num_heads, key_len, head_dim] + # [batch, 2, num_heads, key_len, head_dim] input_length = ops.shape(self_attention_cache)[3] cache_update_index = ( @@ -417,14 +417,12 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, cache_update_index ) - # Combine causal and user-provided masks return ( ops.minimum(decoder_mask, causal_mask) if decoder_mask is not None else causal_mask ) - def build(self, input_shape): """ Builds the sub-layers based on the input shape. 
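The mask rework in the last few patches converges on a rectangular causal mask for cached decoding: rows index the incoming query positions, columns index every key position already held in the cache. For reference, here is a minimal sketch of that mask written with plain keras.ops; the helper name, argument names, and shapes are illustrative assumptions, not code from this series, and the actual layers rely on keras-hub's own mask utilities (compute_causal_mask, merge_padding_and_attention_mask).

    from keras import ops

    def rectangular_causal_mask(
        batch_size, input_length, output_length, cache_update_index
    ):
        # A query written at absolute position `cache_update_index + i` may
        # attend to every cached key position j with j <= that position.
        row = ops.expand_dims(ops.arange(output_length), 1) + cache_update_index
        col = ops.expand_dims(ops.arange(input_length), 0)
        mask = ops.cast(ops.less_equal(col, row), "int32")
        # Shape: (batch_size, output_length, input_length).
        return ops.broadcast_to(
            ops.expand_dims(mask, 0),
            (batch_size, output_length, input_length),
        )

    # Prefill: output_length == input_length and cache_update_index == 0,
    # which reduces to the usual lower-triangular mask. Single-token
    # generation: output_length == 1 and the one query row attends to every
    # position up to and including cache_update_index.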
From 53efb595e34362ffe96a49006fa0010835b8aff9 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 26 Jul 2025 21:57:43 +0900 Subject: [PATCH 31/76] god has answered my prayers --- keras_hub/src/models/smollm3/smollm3_layers.py | 8 ++------ keras_hub/src/models/smollm3/smollm3_utils.py | 3 --- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index b291d79e51..097e8161dd 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -151,18 +151,13 @@ def _compute_kv_values(x_input): if self_attention_cache is not None: key_cache = self_attention_cache[:, 0, ...] value_cache = self_attention_cache[:, 1, ...] - print("key_cache", key_cache.shape) - print("value_cache", value_cache.shape) if self_attention_cache_update_index is None: key_states = key_cache value_states = value_cache else: key_update, value_update = _compute_kv_values(hidden_states) - print("key_update", key_update.shape) - print("value_update", value_update.shape) start = [0, 0, self_attention_cache_update_index, 0] - print("start", start) key_states = ops.slice_update(key_cache, start, key_update) value_states = ops.slice_update( value_cache, start, value_update @@ -409,7 +404,8 @@ def _compute_self_attention_mask( input_length = ops.shape(self_attention_cache)[3] cache_update_index = ( - 0 if self_attention_cache_update_index is None + 0 + if self_attention_cache_update_index is None else self_attention_cache_update_index ) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index aae2dd39e5..2bcd8d64f7 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -53,9 +53,6 @@ def eager_attention_forward( * scaling ) - print("attn_weights", attn_weights.shape) - print("attention_mask", attention_mask.shape) - if attention_mask is not None: attn_weights = softmax_op(attn_weights, attention_mask[:, None, :, :]) else: From c136080394b5c66a37e74132a04f7988367f69af Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 20:06:31 +0900 Subject: [PATCH 32/76] Simplify position ids --- keras_hub/src/models/smollm3/smollm3_backbone.py | 1 - keras_hub/src/models/smollm3/smollm3_causal_lm.py | 14 ++------------ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 4d6401707f..74f3d3bc5a 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -78,7 +78,6 @@ def __init__( name="token_embedding", ) self.transformer_layers = [] - for i in range(num_layers): layer = SmolLM3DecoderLayer( hidden_size=hidden_dim, diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index e479f51124..1262c60298 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -30,8 +30,7 @@ def __init__(self, backbone, preprocessor=None, **kwargs): # rather than "backbone.inputs" which is the flattened list of inputs. 
inputs = backbone.input hidden_states = backbone(inputs) - # Use self.backbone for clarity and consistency - outputs = self.backbone.token_embedding(hidden_states, reverse=True) + outputs = backbone.token_embedding(hidden_states, reverse=True) super().__init__( inputs=inputs, outputs=outputs, @@ -72,16 +71,7 @@ def call_with_cache( # Infer position_ids based on the input shape. seq_len = ops.shape(token_ids)[1] - if seq_len > 1: - # Prefill stage for the initial prompt. - position_ids = ops.arange(0, seq_len, dtype="int32") - position_ids = ops.expand_dims(position_ids, axis=0) - else: - # Decoding stage for a single token. - batch_size = ops.shape(token_ids)[0] - position_ids = ops.full( - (batch_size, 1), cache_update_index, dtype="int32" - ) + position_ids = ops.arange(0, seq_len, dtype="int32") # Each decoder layer has a cache; we update them separately. position_embeddings = self.backbone.rotary_embedding(x, position_ids) From 7b7ebbbbe5476d630f46d72611d7fbebf6938121 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 20:09:41 +0900 Subject: [PATCH 33/76] Simplify position ids --- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index 1262c60298..a5520db9f3 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -71,7 +71,10 @@ def call_with_cache( # Infer position_ids based on the input shape. seq_len = ops.shape(token_ids)[1] + batch_size = ops.shape(token_ids)[0] position_ids = ops.arange(0, seq_len, dtype="int32") + position_ids = ops.expand_dims(position_ids, axis=0) # (1, seq_len) + position_ids = ops.broadcast_to(position_ids, (batch_size, seq_len)) # Each decoder layer has a cache; we update them separately. 
position_embeddings = self.backbone.rotary_embedding(x, position_ids) From 4148384eea3edabc9b2f61dde62d63d4f11eefdc Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 20:30:27 +0900 Subject: [PATCH 34/76] Use existing rotary embeddings --- .../src/models/smollm3/smollm3_backbone.py | 18 ++++++++++++------ .../src/utils/transformers/convert_smollm3.py | 1 + 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 74f3d3bc5a..b807597a18 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -8,6 +8,7 @@ from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding @keras_hub_export( @@ -69,6 +70,7 @@ def __init__( max_position_embeddings, rope_theta, partial_rotary_factor, + rope_scaling, **kwargs, ): # === Layers === @@ -100,12 +102,16 @@ def __init__( name="sequence_output_layernorm", ) - self.rotary_embedding = SmolLM3RotaryEmbedding( - hidden_size=hidden_dim, - num_attention_heads=num_attention_heads, - max_position_embeddings=max_position_embeddings, - rope_theta=rope_theta, - partial_rotary_factor=partial_rotary_factor, + #self.rotary_embedding = SmolLM3RotaryEmbedding( + # hidden_size=hidden_dim, + # num_attention_heads=num_attention_heads, + # max_position_embeddings=max_position_embeddings, + # rope_theta=rope_theta, + # partial_rotary_factor=partial_rotary_factor, + #) + self.rotary_embedding = RotaryEmbedding( + max_wavelength=rope_theta, + scaling_factor=rope_scaling ) # === Functional Model === diff --git a/keras_hub/src/utils/transformers/convert_smollm3.py b/keras_hub/src/utils/transformers/convert_smollm3.py index 46ff0cc2ce..b8b08e9a26 100644 --- a/keras_hub/src/utils/transformers/convert_smollm3.py +++ b/keras_hub/src/utils/transformers/convert_smollm3.py @@ -30,6 +30,7 @@ def convert_backbone_config(transformers_config): "rope_layer_enabled_list": transformers_config["no_rope_layers"], "layer_types": transformers_config["layer_types"], "mlp_bias": transformers_config["mlp_bias"], + "rope_scaling": transformers_config["rope_scaling"] } From d9a0f7aed2c0496b5222fb3b81c286782ce5481b Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 20:31:45 +0900 Subject: [PATCH 35/76] Use existing rotary embeddings --- keras_hub/src/models/smollm3/smollm3_backbone.py | 8 +------- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 9 +-------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index b807597a18..5f5dc30878 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -123,14 +123,8 @@ def __init__( shape=(None,), dtype="int32", name="padding_mask" ) - # Infer position IDs from the shape of token IDs. - seq_len = ops.shape(token_id_input)[1] - position_ids = ops.arange(0, seq_len, dtype="int32") - # Add a batch dimension to broadcast. 
- position_ids = ops.expand_dims(position_ids, axis=0) - hidden_states = self.token_embedding(token_id_input) - position_embeddings = self.rotary_embedding(hidden_states, position_ids) + position_embeddings = self.rotary_embedding(hidden_states) for decoder_layer in self.transformer_layers[:num_layers]: hidden_states = decoder_layer( diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index a5520db9f3..b4da91d1b1 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -69,15 +69,8 @@ def call_with_cache( """ x = self.backbone.token_embedding(token_ids) - # Infer position_ids based on the input shape. - seq_len = ops.shape(token_ids)[1] - batch_size = ops.shape(token_ids)[0] - position_ids = ops.arange(0, seq_len, dtype="int32") - position_ids = ops.expand_dims(position_ids, axis=0) # (1, seq_len) - position_ids = ops.broadcast_to(position_ids, (batch_size, seq_len)) - # Each decoder layer has a cache; we update them separately. - position_embeddings = self.backbone.rotary_embedding(x, position_ids) + position_embeddings = self.backbone.rotary_embedding(x) updated_cache = [] for i in range(self.backbone.num_layers): current_cache = cache[:, i, ...] From 58e87f6f4e2384c46c839915a15780096545b648 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 20:35:18 +0900 Subject: [PATCH 36/76] Use existing rotary embeddings --- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index b4da91d1b1..a3647cf2e9 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -292,13 +292,8 @@ def default_layer_intercept_fn(x, unused_i): token_embeddings = self.backbone.token_embedding(token_ids) x = layer_intercept_fn(token_embeddings, -1) - # Generate position_ids for the full sequence - seq_len = ops.shape(token_ids)[1] - position_ids = ops.arange(0, seq_len, dtype="int32")[None, :] - position_ids = ops.broadcast_to(position_ids, (batch_shape[0], seq_len)) - # Get position embeddings for the full sequence - position_embeddings = self.backbone.rotary_embedding(x, position_ids) + position_embeddings = self.backbone.rotary_embedding(x) for i, transformer_layer in enumerate(self.backbone.transformer_layers): x = transformer_layer( From 5a6fb2763a9ba382010d2586eceef4b6f5fe9e2f Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 20:38:52 +0900 Subject: [PATCH 37/76] pass dtype policy --- keras_hub/src/models/smollm3/smollm3_backbone.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 5f5dc30878..31e587ceee 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -111,7 +111,8 @@ def __init__( #) self.rotary_embedding = RotaryEmbedding( max_wavelength=rope_theta, - scaling_factor=rope_scaling + scaling_factor=rope_scaling, + dtype=self.dtype_policy ) # === Functional Model === From e17dd996d50a412ca158e1e27539774a24bc6a52 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 20:42:29 +0900 Subject: [PATCH 38/76] pass dtype policy --- keras_hub/src/models/smollm3/smollm3_backbone.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) 
diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 31e587ceee..365ae8fab4 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -112,7 +112,7 @@ def __init__( self.rotary_embedding = RotaryEmbedding( max_wavelength=rope_theta, scaling_factor=rope_scaling, - dtype=self.dtype_policy + dtype=self.token_embedding.dtype_policy ) # === Functional Model === @@ -161,6 +161,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta self.partial_rotary_factor = partial_rotary_factor + self.rope_scaling = rope_scaling def get_config(self): config = super().get_config() @@ -181,6 +182,7 @@ def get_config(self): "max_position_embeddings": self.max_position_embeddings, "rope_theta": self.rope_theta, "partial_rotary_factor": self.partial_rotary_factor, + "rope_scaling": self.rope_scaling } ) return config From 4c4e1e0370faf162e28946afbeab7c18cadf63cc Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 20:46:30 +0900 Subject: [PATCH 39/76] pass dtype policy --- keras_hub/src/models/smollm3/smollm3_backbone.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 365ae8fab4..93f2c471be 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -70,7 +70,7 @@ def __init__( max_position_embeddings, rope_theta, partial_rotary_factor, - rope_scaling, + rope_scaling=1, **kwargs, ): # === Layers === From c8b7423922bfef9b9cb05af844ab30c59125e9b0 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 20:48:42 +0900 Subject: [PATCH 40/76] pass dtype policy --- keras_hub/src/utils/transformers/convert_smollm3.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras_hub/src/utils/transformers/convert_smollm3.py b/keras_hub/src/utils/transformers/convert_smollm3.py index b8b08e9a26..23ab7c9210 100644 --- a/keras_hub/src/utils/transformers/convert_smollm3.py +++ b/keras_hub/src/utils/transformers/convert_smollm3.py @@ -29,8 +29,7 @@ def convert_backbone_config(transformers_config): "attention_dropout": transformers_config["attention_dropout"], "rope_layer_enabled_list": transformers_config["no_rope_layers"], "layer_types": transformers_config["layer_types"], - "mlp_bias": transformers_config["mlp_bias"], - "rope_scaling": transformers_config["rope_scaling"] + "mlp_bias": transformers_config["mlp_bias"] } From 8aebfd199a0d283b243cda8bce44d03d55954c47 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 21:02:05 +0900 Subject: [PATCH 41/76] refactor rotary embeddings --- .../src/models/smollm3/smollm3_backbone.py | 29 +++++++++---------- .../src/models/smollm3/smollm3_layers.py | 10 +++++-- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 93f2c471be..1d9c94850a 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -8,7 +8,6 @@ from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding -from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding @keras_hub_export( @@ -70,7 
+69,6 @@ def __init__( max_position_embeddings, rope_theta, partial_rotary_factor, - rope_scaling=1, **kwargs, ): # === Layers === @@ -102,17 +100,12 @@ def __init__( name="sequence_output_layernorm", ) - #self.rotary_embedding = SmolLM3RotaryEmbedding( - # hidden_size=hidden_dim, - # num_attention_heads=num_attention_heads, - # max_position_embeddings=max_position_embeddings, - # rope_theta=rope_theta, - # partial_rotary_factor=partial_rotary_factor, - #) - self.rotary_embedding = RotaryEmbedding( - max_wavelength=rope_theta, - scaling_factor=rope_scaling, - dtype=self.token_embedding.dtype_policy + self.rotary_embedding = SmolLM3RotaryEmbedding( + hidden_size=hidden_dim, + num_attention_heads=num_attention_heads, + max_position_embeddings=max_position_embeddings, + rope_theta=rope_theta, + partial_rotary_factor=partial_rotary_factor, ) # === Functional Model === @@ -124,8 +117,14 @@ def __init__( shape=(None,), dtype="int32", name="padding_mask" ) + cache_update_index = kwargs.get('self_attention_cache_index') + + start_index = ( + cache_update_index if cache_update_index is not None else 0 + ) + hidden_states = self.token_embedding(token_id_input) - position_embeddings = self.rotary_embedding(hidden_states) + position_embeddings = self.rotary_embedding(hidden_states, start_index) for decoder_layer in self.transformer_layers[:num_layers]: hidden_states = decoder_layer( @@ -161,7 +160,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta self.partial_rotary_factor = partial_rotary_factor - self.rope_scaling = rope_scaling def get_config(self): config = super().get_config() @@ -182,7 +180,6 @@ def get_config(self): "max_position_embeddings": self.max_position_embeddings, "rope_theta": self.rope_theta, "partial_rotary_factor": self.partial_rotary_factor, - "rope_scaling": self.rope_scaling } ) return config diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 097e8161dd..45dc0494ce 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -583,7 +583,7 @@ def build(self, input_shape): def call( self, x, - position_ids, + start_index=0, ): """ Forward pass for SmolLM3RotaryEmbedding. 
@@ -596,13 +596,17 @@ def call( inv_freq_expanded = ops.expand_dims( ops.expand_dims(self.inv_freq, axis=0), axis=-1 ) + + batch_size = ops.shape(x)[0] + seq_len = ops.shape(x)[1] + positions = ops.arange(seq_len, dtype="float32") + positions + ops.cast(start_index, dtype="float32") - batch_size = ops.shape(position_ids)[0] inv_freq_expanded = ops.broadcast_to( inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1) ) - position_ids_expanded = ops.expand_dims(position_ids, axis=1) + position_ids_expanded = ops.expand_dims(positions, axis=1) freqs = ops.matmul( ops.cast(inv_freq_expanded, "float32"), From 2c674dc00c20466024085c140fe5cf9561693a7d Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 21:03:33 +0900 Subject: [PATCH 42/76] refactor rotary embeddings --- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index a3647cf2e9..4682665f8b 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -70,7 +70,7 @@ def call_with_cache( x = self.backbone.token_embedding(token_ids) # Each decoder layer has a cache; we update them separately. - position_embeddings = self.backbone.rotary_embedding(x) + position_embeddings = self.backbone.rotary_embedding(x, start_index=cache_update_index) updated_cache = [] for i in range(self.backbone.num_layers): current_cache = cache[:, i, ...] From 06472bb811ec929a7c4d30c781ae312376120cc4 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 21:04:05 +0900 Subject: [PATCH 43/76] refactor rotary embeddings --- keras_hub/src/models/smollm3/smollm3_backbone.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 1d9c94850a..ea6e8f8358 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -124,7 +124,7 @@ def __init__( ) hidden_states = self.token_embedding(token_id_input) - position_embeddings = self.rotary_embedding(hidden_states, start_index) + position_embeddings = self.rotary_embedding(hidden_states, start_index=start_index) for decoder_layer in self.transformer_layers[:num_layers]: hidden_states = decoder_layer( From f913179e6b9111f9457c84a146b4d36c74583a42 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 21:06:32 +0900 Subject: [PATCH 44/76] refactor rotary embeddings --- keras_hub/src/models/smollm3/smollm3_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 45dc0494ce..d108148e78 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -600,7 +600,7 @@ def call( batch_size = ops.shape(x)[0] seq_len = ops.shape(x)[1] positions = ops.arange(seq_len, dtype="float32") - positions + ops.cast(start_index, dtype="float32") + positions = positions + ops.cast(start_index, dtype="float32") inv_freq_expanded = ops.broadcast_to( inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1) From a663a5c002dc81c97b8f2a97996c06f72b6bc1ab Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 21:13:45 +0900 Subject: [PATCH 45/76] refactor rotary embeddings --- keras_hub/src/models/smollm3/smollm3_layers.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index d108148e78..2491e78c6e 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -606,7 +606,7 @@ def call( inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1) ) - position_ids_expanded = ops.expand_dims(positions, axis=1) + position_ids_expanded = ops.expand_dims(positions, axis=1).T freqs = ops.matmul( ops.cast(inv_freq_expanded, "float32"), From 630cc7011b66771f3b20d1ac90b936ee2aca99ec Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 21:50:13 +0900 Subject: [PATCH 46/76] log cache_update_index --- keras_hub/src/models/smollm3/smollm3_backbone.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index ea6e8f8358..5dabea5b22 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -118,6 +118,7 @@ def __init__( ) cache_update_index = kwargs.get('self_attention_cache_index') + print(cache_update_index) start_index = ( cache_update_index if cache_update_index is not None else 0 From de79b8d68c724b7420593dcb126689562d977dc8 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 21:58:59 +0900 Subject: [PATCH 47/76] rotary embed in loop --- keras_hub/src/models/smollm3/smollm3_backbone.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 5dabea5b22..84534a113e 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -118,16 +118,16 @@ def __init__( ) cache_update_index = kwargs.get('self_attention_cache_index') - print(cache_update_index) start_index = ( cache_update_index if cache_update_index is not None else 0 ) hidden_states = self.token_embedding(token_id_input) - position_embeddings = self.rotary_embedding(hidden_states, start_index=start_index) + for decoder_layer in self.transformer_layers[:num_layers]: + position_embeddings = self.rotary_embedding(hidden_states, start_index=start_index) hidden_states = decoder_layer( hidden_states, position_embeddings=position_embeddings, From fc5974d0a897dd32cb942dcb18dfddf3bdfba5ff Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 22:04:31 +0900 Subject: [PATCH 48/76] log cache_update_index --- keras_hub/src/models/smollm3/smollm3_backbone.py | 3 +-- keras_hub/src/models/smollm3/smollm3_layers.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index 84534a113e..ea6e8f8358 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -124,10 +124,9 @@ def __init__( ) hidden_states = self.token_embedding(token_id_input) - + position_embeddings = self.rotary_embedding(hidden_states, start_index=start_index) for decoder_layer in self.transformer_layers[:num_layers]: - position_embeddings = self.rotary_embedding(hidden_states, start_index=start_index) hidden_states = decoder_layer( hidden_states, position_embeddings=position_embeddings, diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 
2491e78c6e..bfb6131350 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -601,6 +601,7 @@ def call( seq_len = ops.shape(x)[1] positions = ops.arange(seq_len, dtype="float32") positions = positions + ops.cast(start_index, dtype="float32") + print(start_index) inv_freq_expanded = ops.broadcast_to( inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1) From 35756368bbba111be551dbdb9a60636730dadd0d Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 5 Aug 2025 22:09:19 +0900 Subject: [PATCH 49/76] rotary embed in loop --- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index 4682665f8b..953318ffea 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -70,9 +70,10 @@ def call_with_cache( x = self.backbone.token_embedding(token_ids) # Each decoder layer has a cache; we update them separately. - position_embeddings = self.backbone.rotary_embedding(x, start_index=cache_update_index) + updated_cache = [] for i in range(self.backbone.num_layers): + position_embeddings = self.backbone.rotary_embedding(x, start_index=cache_update_index) current_cache = cache[:, i, ...] x, next_cache = self.backbone.transformer_layers[i]( x, From c71ea2e18c6dd123accf666b29cf6875494f6b14 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 16 Aug 2025 22:57:18 +0900 Subject: [PATCH 50/76] small refactor --- keras_hub/src/models/smollm3/smollm3_backbone.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py index ea6e8f8358..7b064bfbff 100644 --- a/keras_hub/src/models/smollm3/smollm3_backbone.py +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -123,18 +123,18 @@ def __init__( cache_update_index if cache_update_index is not None else 0 ) - hidden_states = self.token_embedding(token_id_input) - position_embeddings = self.rotary_embedding(hidden_states, start_index=start_index) + x = self.token_embedding(token_id_input) + position_embeddings = self.rotary_embedding(x, start_index=start_index) - for decoder_layer in self.transformer_layers[:num_layers]: - hidden_states = decoder_layer( - hidden_states, + for decoder_layer in self.transformer_layers: + x = decoder_layer( + x, position_embeddings=position_embeddings, decoder_padding_mask=padding_mask_input, **kwargs, ) - sequence_output = self.norm(hidden_states) + sequence_output = self.norm(x) super().__init__( inputs={ "token_ids": token_id_input, From bb905f3ee8b43def984f73bd5300bdc9440b2ec3 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 16 Aug 2025 22:58:58 +0900 Subject: [PATCH 51/76] add logging --- keras_hub/src/models/smollm3/smollm3_causal_lm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index 953318ffea..8ab2dfc4d1 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -72,8 +72,10 @@ def call_with_cache( # Each decoder layer has a cache; we update them separately. 
updated_cache = [] + position_embeddings = self.backbone.rotary_embedding(x, start_index=cache_update_index) + for i in range(self.backbone.num_layers): - position_embeddings = self.backbone.rotary_embedding(x, start_index=cache_update_index) + print(f"Decoder layer {i}") current_cache = cache[:, i, ...] x, next_cache = self.backbone.transformer_layers[i]( x, @@ -81,6 +83,7 @@ def call_with_cache( self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) + print(x.shape, x) updated_cache.append(next_cache) cache = ops.stack(updated_cache, axis=1) hidden_states = x = self.backbone.norm(x) From edbb757ff099f5fa129905c8bf5ec43de3d13e4c Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 16 Aug 2025 23:34:57 +0900 Subject: [PATCH 52/76] more logging --- keras_hub/src/models/llama/llama_causal_lm.py | 1 + keras_hub/src/models/smollm3/smollm3_causal_lm.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/llama/llama_causal_lm.py b/keras_hub/src/models/llama/llama_causal_lm.py index 7f0f901d52..c6d490dc73 100644 --- a/keras_hub/src/models/llama/llama_causal_lm.py +++ b/keras_hub/src/models/llama/llama_causal_lm.py @@ -88,6 +88,7 @@ def call_with_cache( self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) + print(next_cache.shape) updated_cache.append(next_cache) cache = ops.stack(updated_cache, axis=1) hidden_states = x = self.backbone.layer_norm(x) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index 8ab2dfc4d1..c2491995e9 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -70,7 +70,6 @@ def call_with_cache( x = self.backbone.token_embedding(token_ids) # Each decoder layer has a cache; we update them separately. 
- updated_cache = [] position_embeddings = self.backbone.rotary_embedding(x, start_index=cache_update_index) @@ -83,7 +82,7 @@ def call_with_cache( self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) - print(x.shape, x) + print(next_cache.shape) updated_cache.append(next_cache) cache = ops.stack(updated_cache, axis=1) hidden_states = x = self.backbone.norm(x) From 6cf842249d8c6c3106a6f5a7cee0371a8a36630a Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 16 Aug 2025 23:49:22 +0900 Subject: [PATCH 53/76] don't reshape unnecessarily in compute_kv --- keras_hub/src/models/llama/llama_causal_lm.py | 1 - keras_hub/src/models/qwen3/qwen3_causal_lm.py | 1 + .../src/models/smollm3/smollm3_causal_lm.py | 2 +- .../src/models/smollm3/smollm3_layers.py | 20 ++++++------------- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/keras_hub/src/models/llama/llama_causal_lm.py b/keras_hub/src/models/llama/llama_causal_lm.py index c6d490dc73..7f0f901d52 100644 --- a/keras_hub/src/models/llama/llama_causal_lm.py +++ b/keras_hub/src/models/llama/llama_causal_lm.py @@ -88,7 +88,6 @@ def call_with_cache( self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) - print(next_cache.shape) updated_cache.append(next_cache) cache = ops.stack(updated_cache, axis=1) hidden_states = x = self.backbone.layer_norm(x) diff --git a/keras_hub/src/models/qwen3/qwen3_causal_lm.py b/keras_hub/src/models/qwen3/qwen3_causal_lm.py index f2d7b10b16..5d0cb60a58 100644 --- a/keras_hub/src/models/qwen3/qwen3_causal_lm.py +++ b/keras_hub/src/models/qwen3/qwen3_causal_lm.py @@ -193,6 +193,7 @@ def call_with_cache( self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) + #print(next_cache.shape) updated_cache.append(next_cache) cache = ops.stack(updated_cache, axis=1) hidden_states = x = self.backbone.layer_norm(x) diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py index c2491995e9..57bd66f3ac 100644 --- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -100,8 +100,8 @@ def _build_cache(self, token_ids): batch_size, num_layers, 2, - num_key_value_heads, max_length, + num_key_value_heads, head_dim, ] cache = ops.zeros(shape, dtype=self.compute_dtype) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index bfb6131350..55a9a4b31d 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -130,22 +130,14 @@ def call( query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape) # (batch, num_heads, seq_len, head_dim) - query_states = ops.transpose(query_states, axes=(0, 2, 1, 3)) + #query_states = ops.transpose(query_states, axes=(0, 2, 1, 3)) def _compute_kv_values(x_input): - kv_hidden_shape = ( - *input_shape, - self.num_key_value_heads, - self.head_dim, - ) - - key_states_raw = ops.reshape(self.k_proj(x_input), kv_hidden_shape) - value_states_raw = ops.reshape( - self.v_proj(x_input), kv_hidden_shape - ) + key_states = self.k_proj(x_input) + value_states = self.v_proj(x_input) - key_states = ops.transpose(key_states_raw, axes=(0, 2, 1, 3)) - value_states = ops.transpose(value_states_raw, axes=(0, 2, 1, 3)) + #key_states = ops.transpose(key_states_raw, axes=(0, 2, 1, 3)) + #value_states = ops.transpose(value_states_raw, axes=(0, 2, 1, 3)) return key_states, value_states 
if self_attention_cache is not None: @@ -157,7 +149,7 @@ def _compute_kv_values(x_input): value_states = value_cache else: key_update, value_update = _compute_kv_values(hidden_states) - start = [0, 0, self_attention_cache_update_index, 0] + start = [0, self_attention_cache_update_index, 0, 0] key_states = ops.slice_update(key_cache, start, key_update) value_states = ops.slice_update( value_cache, start, value_update From 7b193e52dfb81fdb6b0a4ecd3aa66857c2026784 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 16 Aug 2025 23:52:57 +0900 Subject: [PATCH 54/76] don't reshape unnecessarily in compute_kv --- keras_hub/src/models/smollm3/smollm3_layers.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 55a9a4b31d..d61120e85c 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -126,18 +126,12 @@ def call( ) input_shape = ops.shape(hidden_states)[:-1] - hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim) - - query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape) - # (batch, num_heads, seq_len, head_dim) - #query_states = ops.transpose(query_states, axes=(0, 2, 1, 3)) + query_states = self.q_proj(hidden_states) def _compute_kv_values(x_input): key_states = self.k_proj(x_input) value_states = self.v_proj(x_input) - #key_states = ops.transpose(key_states_raw, axes=(0, 2, 1, 3)) - #value_states = ops.transpose(value_states_raw, axes=(0, 2, 1, 3)) return key_states, value_states if self_attention_cache is not None: From 3ede850b59b8e84cc6d2f733cf0a87f6be41f66f Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 16 Aug 2025 23:56:47 +0900 Subject: [PATCH 55/76] don't reshape unnecessarily in compute_kv --- .../src/models/smollm3/smollm3_layers.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index d61120e85c..f7355c7c31 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -126,12 +126,26 @@ def call( ) input_shape = ops.shape(hidden_states)[:-1] - query_states = self.q_proj(hidden_states) + hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim) + + query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape) + # (batch, num_heads, seq_len, head_dim) + query_states = ops.transpose(query_states, axes=(0, 2, 1, 3)) def _compute_kv_values(x_input): - key_states = self.k_proj(x_input) - value_states = self.v_proj(x_input) + kv_hidden_shape = ( + *input_shape, + self.num_key_value_heads, + self.head_dim, + ) + + key_states = ops.reshape(self.k_proj(x_input), kv_hidden_shape) + value_states = ops.reshape( + self.v_proj(x_input), kv_hidden_shape + ) + #key_states = ops.transpose(key_states_raw, axes=(0, 2, 1, 3)) + #value_states = ops.transpose(value_states_raw, axes=(0, 2, 1, 3)) return key_states, value_states if self_attention_cache is not None: From 327e2bf0bd8bb8386902386e70fd15a803e1a42a Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 16 Aug 2025 23:57:00 +0900 Subject: [PATCH 56/76] don't reshape unnecessarily in compute_kv --- keras_hub/src/models/smollm3/smollm3_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 
f7355c7c31..24743b074c 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -130,7 +130,7 @@ def call( query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape) # (batch, num_heads, seq_len, head_dim) - query_states = ops.transpose(query_states, axes=(0, 2, 1, 3)) + #query_states = ops.transpose(query_states, axes=(0, 2, 1, 3)) def _compute_kv_values(x_input): kv_hidden_shape = ( From a798d359761508f5b8f997f76b4c54950250276e Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 17 Aug 2025 00:06:24 +0900 Subject: [PATCH 57/76] don't reshape unnecessarily in compute_kv --- keras_hub/src/models/smollm3/smollm3_layers.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 24743b074c..0236a23650 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -51,6 +51,14 @@ def __init__( self.rope_layer_enabled_list = rope_layer_enabled_list self.layer_types = layer_types + self.rotary_embedding = SmolLM3RotaryEmbedding( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + max_position_embeddings=65536, + rope_theta=5000000.0, + partial_rotary_factor=0.5, + ) + self.layer_idx = layer_idx self.head_dim = self.hidden_size // self.num_attention_heads @@ -124,6 +132,9 @@ def call( self_attention_cache_update_index = kwargs.get( "self_attention_cache_update_index", None ) + start_index = ( + self_attention_cache_update_index if self_attention_cache_update_index is not None else 0 + ) input_shape = ops.shape(hidden_states)[:-1] hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim) @@ -175,10 +186,8 @@ def _compute_kv_values(x_input): key_states, value_states = _compute_kv_values(hidden_states) if self.use_rope: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) + query_states = self.rotary_embedding(query_states, start_index=start_index) + key_states = self.rotary_embedding(key_states, start_index=start_index) attn_output = eager_attention_forward( module=self, @@ -601,7 +610,6 @@ def call( seq_len = ops.shape(x)[1] positions = ops.arange(seq_len, dtype="float32") positions = positions + ops.cast(start_index, dtype="float32") - print(start_index) inv_freq_expanded = ops.broadcast_to( inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1) From 9b5cd1158176d914ef2c1fcc460703e78cff1009 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 17 Aug 2025 00:09:18 +0900 Subject: [PATCH 58/76] don't reshape unnecessarily in compute_kv --- keras_hub/src/models/smollm3/smollm3_layers.py | 16 +++++++++------- keras_hub/src/models/smollm3/smollm3_utils.py | 8 ++------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 0236a23650..eb05c37920 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -139,7 +139,7 @@ def call( input_shape = ops.shape(hidden_states)[:-1] hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim) - query_states = ops.reshape(self.q_proj(hidden_states), hidden_shape) + query = ops.reshape(self.q_proj(hidden_states), hidden_shape) # (batch, num_heads, seq_len, head_dim) #query_states = ops.transpose(query_states, axes=(0, 2, 1, 
3)) @@ -186,14 +186,16 @@ def _compute_kv_values(x_input): key_states, value_states = _compute_kv_values(hidden_states) if self.use_rope: - query_states = self.rotary_embedding(query_states, start_index=start_index) - key_states = self.rotary_embedding(key_states, start_index=start_index) + query = self.rotary_embedding(query, start_index=start_index) + key = self.rotary_embedding(key_states, start_index=start_index) + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + attn_output = eager_attention_forward( - module=self, - query=query_states, - key=key_states, - value=value_states, + query=query, + key=key, + value=value, dropout=self.attention_dropout, scaling=self.scaling, training=self.training, diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 2bcd8d64f7..67093bbe45 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -30,7 +30,6 @@ def repeat_kv(hidden_states, n_rep): def eager_attention_forward( - module, query, key, value, @@ -45,11 +44,8 @@ def eager_attention_forward( name="attention_softmax", ) - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = ( - ops.matmul(query, ops.transpose(key_states, axes=(0, 1, 3, 2))) + ops.matmul(query, ops.transpose(key, axes=(0, 1, 3, 2))) * scaling ) @@ -60,7 +56,7 @@ def eager_attention_forward( if training: attn_weights = random.dropout(attn_weights, rate=dropout) - attn_output = ops.matmul(attn_weights, value_states) + attn_output = ops.matmul(attn_weights, value) attn_output = ops.transpose(attn_output, axes=(0, 2, 1, 3)) return attn_output From 308a68287c995c0dc48cf914f04a04618a92fe7b Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 17 Aug 2025 00:11:26 +0900 Subject: [PATCH 59/76] don't reshape unnecessarily in compute_kv --- .../src/models/smollm3/smollm3_layers.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index eb05c37920..ea7ca01c4d 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -140,8 +140,6 @@ def call( hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim) query = ops.reshape(self.q_proj(hidden_states), hidden_shape) - # (batch, num_heads, seq_len, head_dim) - #query_states = ops.transpose(query_states, axes=(0, 2, 1, 3)) def _compute_kv_values(x_input): kv_hidden_shape = ( @@ -150,31 +148,29 @@ def _compute_kv_values(x_input): self.head_dim, ) - key_states = ops.reshape(self.k_proj(x_input), kv_hidden_shape) - value_states = ops.reshape( + key = ops.reshape(self.k_proj(x_input), kv_hidden_shape) + value = ops.reshape( self.v_proj(x_input), kv_hidden_shape ) - #key_states = ops.transpose(key_states_raw, axes=(0, 2, 1, 3)) - #value_states = ops.transpose(value_states_raw, axes=(0, 2, 1, 3)) - return key_states, value_states + return key, value if self_attention_cache is not None: key_cache = self_attention_cache[:, 0, ...] value_cache = self_attention_cache[:, 1, ...] 
if self_attention_cache_update_index is None: - key_states = key_cache - value_states = value_cache + key = key_cache + value = value_cache else: key_update, value_update = _compute_kv_values(hidden_states) start = [0, self_attention_cache_update_index, 0, 0] - key_states = ops.slice_update(key_cache, start, key_update) - value_states = ops.slice_update( + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update( value_cache, start, value_update ) self_attention_cache = ops.stack( - (key_states, value_states), axis=1 + (key, value), axis=1 ) else: if self_attention_cache_update_index is not None: @@ -183,11 +179,11 @@ def _compute_kv_values(x_input): f"`None`. Received: self_attention_cache={self_attention_cache}, " f"self_attention_cache_update_index={self_attention_cache_update_index}" ) - key_states, value_states = _compute_kv_values(hidden_states) + key, value = _compute_kv_values(hidden_states) if self.use_rope: query = self.rotary_embedding(query, start_index=start_index) - key = self.rotary_embedding(key_states, start_index=start_index) + key = self.rotary_embedding(key, start_index=start_index) key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) From 6c65160d80aac62e2195701f88657cce3bf47729 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 17 Aug 2025 00:14:33 +0900 Subject: [PATCH 60/76] don't reshape unnecessarily in compute_kv --- .../src/models/smollm3/smollm3_layers.py | 24 ++++--------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index ea7ca01c4d..4bcecd6924 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -385,30 +385,16 @@ def _compute_self_attention_mask( self_attention_cache, self_attention_cache_update_index, ): - """Computes the self-attention mask combining causal, padding and - attention masks. - - Args: - decoder_sequence: Input tensor. - decoder_padding_mask: Mask tensor for padding tokens. - decoder_attention_mask: Additional attention mask. - self_attention_cache: Optional cached key and value tensors. - self_attention_cache_update_index: Index at which to update the - cache. - - Returns: - Combined attention mask tensor. - """ decoder_mask = merge_padding_and_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask ) batch_size = ops.shape(decoder_sequence)[0] - output_length = ops.shape(decoder_sequence)[1] - input_length = output_length - + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. 
if self_attention_cache is not None: - # [batch, 2, num_heads, key_len, head_dim] - input_length = ops.shape(self_attention_cache)[3] + input_length = ops.shape(self_attention_cache)[2] cache_update_index = ( 0 From a56474cb8587248541b843aabf95cd9e1f4d534b Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 17 Aug 2025 13:03:26 +0900 Subject: [PATCH 61/76] logging --- keras_hub/src/models/smollm3/smollm3_layers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 4bcecd6924..408fd66898 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -185,8 +185,10 @@ def _compute_kv_values(x_input): query = self.rotary_embedding(query, start_index=start_index) key = self.rotary_embedding(key, start_index=start_index) + print('pre', key.shape, value.shape) key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + print('post', key.shape, value.shape) attn_output = eager_attention_forward( query=query, From 973b0e535ea38b14811d0cd3cf4e6d570de871c8 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 17 Aug 2025 13:19:36 +0900 Subject: [PATCH 62/76] adjust how rope is applied --- keras_hub/src/models/smollm3/smollm3_layers.py | 9 ++++++--- keras_hub/src/models/smollm3/smollm3_utils.py | 7 +++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 408fd66898..4d78872770 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -9,7 +9,7 @@ from keras_hub.src.layers.modeling.transformer_layer_utils import ( merge_padding_and_attention_mask, ) -from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb +from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb, apply_rotary_pos_single from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward from keras_hub.src.models.smollm3.smollm3_utils import rope_init @@ -182,8 +182,11 @@ def _compute_kv_values(x_input): key, value = _compute_kv_values(hidden_states) if self.use_rope: - query = self.rotary_embedding(query, start_index=start_index) - key = self.rotary_embedding(key, start_index=start_index) + query_cos, query_sin = self.rotary_embedding(query, start_index=start_index) + query = apply_rotary_pos_single(query, query_cos, query_sin) + + key_cos, key_sin = self.rotary_embedding(key, start_index=start_index) + key = apply_rotary_pos_single(key, key_cos, key_sin) print('pre', key.shape, value.shape) key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 67093bbe45..07b895883a 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -17,6 +17,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1): return q_embed, k_embed +def apply_rotary_pos_single(tensor, cos, sin, expansion_axis=1): + cos = ops.expand_dims(cos, expansion_axis) + sin = ops.expand_dims(sin, expansion_axis) + tensor_embed = (tensor * cos) + (rotate_half(tensor) * sin) + return tensor_embed + + def repeat_kv(hidden_states, n_rep): batch, num_key_value_heads, slen, head_dim = ops.shape(hidden_states) if n_rep == 1: From 
b940997d029698fdc7b0102928644045cb417597 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 17 Aug 2025 13:22:48 +0900 Subject: [PATCH 63/76] adjust how rope is applied --- keras_hub/src/models/smollm3/smollm3_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py index 07b895883a..53b831c18f 100644 --- a/keras_hub/src/models/smollm3/smollm3_utils.py +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -18,6 +18,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1): def apply_rotary_pos_single(tensor, cos, sin, expansion_axis=1): + print('tensor', tensor.shape) + print('cos', cos.shape) cos = ops.expand_dims(cos, expansion_axis) sin = ops.expand_dims(sin, expansion_axis) tensor_embed = (tensor * cos) + (rotate_half(tensor) * sin) From 98daba9a8d94c1588a46a6318134d382df6d793d Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 17 Aug 2025 13:35:13 +0900 Subject: [PATCH 64/76] switch to kerashub rotaryembedding --- keras_hub/src/models/smollm3/smollm3_layers.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 4d78872770..52584db989 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -12,6 +12,7 @@ from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb, apply_rotary_pos_single from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward from keras_hub.src.models.smollm3.smollm3_utils import rope_init +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding class SmolLM3Attention(layers.Layer): @@ -51,12 +52,13 @@ def __init__( self.rope_layer_enabled_list = rope_layer_enabled_list self.layer_types = layer_types - self.rotary_embedding = SmolLM3RotaryEmbedding( + self.rotary_embedding = RotaryEmbedding( hidden_size=hidden_size, num_attention_heads=num_attention_heads, max_position_embeddings=65536, - rope_theta=5000000.0, + max_wavelength=5000000.0, partial_rotary_factor=0.5, + scaling_factor=1.0 ) self.layer_idx = layer_idx @@ -182,11 +184,8 @@ def _compute_kv_values(x_input): key, value = _compute_kv_values(hidden_states) if self.use_rope: - query_cos, query_sin = self.rotary_embedding(query, start_index=start_index) - query = apply_rotary_pos_single(query, query_cos, query_sin) - - key_cos, key_sin = self.rotary_embedding(key, start_index=start_index) - key = apply_rotary_pos_single(key, key_cos, key_sin) + query = self.rotary_embedding(query, start_index=start_index) + key = self.rotary_embedding(key, start_index=start_index) print('pre', key.shape, value.shape) key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) From 783f8d7ebe5061068769dbec04da9726c73d8883 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 17 Aug 2025 13:37:53 +0900 Subject: [PATCH 65/76] switch to kerashub rotaryembedding --- keras_hub/src/models/smollm3/smollm3_layers.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py index 52584db989..40ce00e156 100644 --- a/keras_hub/src/models/smollm3/smollm3_layers.py +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -53,12 +53,7 @@ def __init__( self.layer_types = layer_types self.rotary_embedding = RotaryEmbedding( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - 
From 783f8d7ebe5061068769dbec04da9726c73d8883 Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 13:37:53 +0900
Subject: [PATCH 65/76] switch to kerashub rotaryembedding

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index 52584db989..40ce00e156 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -53,12 +53,7 @@ def __init__(
         self.layer_types = layer_types
 
         self.rotary_embedding = RotaryEmbedding(
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            max_position_embeddings=65536,
             max_wavelength=5000000.0,
-            partial_rotary_factor=0.5,
-            scaling_factor=1.0
         )
 
         self.layer_idx = layer_idx

From 97aea007ec57429303e711e1fb2cc0b8fe9a1e03 Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 13:42:47 +0900
Subject: [PATCH 66/76] remove reshape

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index 40ce00e156..e0e35fe9e2 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -197,7 +197,7 @@ def _compute_kv_values(x_input):
             attention_mask=attention_mask,
         )
 
-        attn_output = ops.reshape(attn_output, (*input_shape, self.hidden_size))
+        #attn_output = ops.reshape(attn_output, (*input_shape, self.hidden_size))
 
         attn_output = self.o_proj(attn_output)
 

From c31c889a064908222ef03d2bdd56d61fcc3cf2d6 Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 13:51:54 +0900
Subject: [PATCH 67/76] fix reshape

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 4 +++-
 keras_hub/src/models/smollm3/smollm3_utils.py  | 2 --
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index e0e35fe9e2..c0e56276a0 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -197,10 +197,12 @@ def _compute_kv_values(x_input):
             attention_mask=attention_mask,
         )
 
-        #attn_output = ops.reshape(attn_output, (*input_shape, self.hidden_size))
+        attn_output = ops.reshape(attn_output, (*input_shape, -1))
 
         attn_output = self.o_proj(attn_output)
 
+
+
         if self_attention_cache is not None:
             return attn_output, self_attention_cache
 
diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py
index 53b831c18f..07b895883a 100644
--- a/keras_hub/src/models/smollm3/smollm3_utils.py
+++ b/keras_hub/src/models/smollm3/smollm3_utils.py
@@ -18,8 +18,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1):
 
 
 def apply_rotary_pos_single(tensor, cos, sin, expansion_axis=1):
-    print('tensor', tensor.shape)
-    print('cos', cos.shape)
     cos = ops.expand_dims(cos, expansion_axis)
     sin = ops.expand_dims(sin, expansion_axis)
     tensor_embed = (tensor * cos) + (rotate_half(tensor) * sin)

From 3aaa3e15ab9171bf8f388c505c109704d7b72b15 Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 13:57:02 +0900
Subject: [PATCH 68/76] new attention computation

---
 .../src/models/smollm3/smollm3_layers.py      | 67 ++++++++++++++++---
 1 file changed, 58 insertions(+), 9 deletions(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index c0e56276a0..1982b6187b 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -187,17 +187,15 @@ def _compute_kv_values(x_input):
         value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
         print('post', key.shape, value.shape)
 
-        attn_output = eager_attention_forward(
-            query=query,
-            key=key,
-            value=value,
-            dropout=self.attention_dropout,
-            scaling=self.scaling,
-            training=self.training,
-            attention_mask=attention_mask,
+        attn_output = self._compute_attention(
+            query,
+            key,
+            value,
+            attention_mask,
+            cache_update_index=self_attention_cache_update_index,
         )
 
-        attn_output = ops.reshape(attn_output, (*input_shape, -1))
+        #attn_output = ops.reshape(attn_output, (*input_shape, -1))
 
         attn_output = self.o_proj(attn_output)
 
@@ -237,6 +235,57 @@ def compute_output_shape(self, input_shape):
         )
         return [output_attn_output_shape, output_attn_weights_shape]
 
+
+
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        """Applies softmax with optional masking.
+
+        Args:
+            attention_scores: Attention score tensor.
+            attention_mask: Optional mask tensor.
+
+        Returns:
+            Masked softmax attention weights.
+        """
+        if attention_mask is not None:
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)
+
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, cache_update_index=None
+    ):
+        """Computes attention using query, key, and value tensors.
+
+        Uses Flash Attention when available for better performance.
+
+        Args:
+            query: Query tensor.
+            key: Key tensor.
+            value: Value tensor.
+            attention_mask: Optional mask tensor.
+            cache_update_index: Index for sliding window computation.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+        """
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+        attention_scores = self._masked_softmax(
+            attention_scores, attention_mask
+        )
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
 
 
 class SmolLM3MLP(layers.Layer):

From 46eed1a874d8100a7f4ceb0ee9df84da357a8c8f Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 13:59:01 +0900
Subject: [PATCH 69/76] new attention computation

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index 1982b6187b..5da639e966 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -51,6 +51,8 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.rope_layer_enabled_list = rope_layer_enabled_list
         self.layer_types = layer_types
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
 
         self.rotary_embedding = RotaryEmbedding(
             max_wavelength=5000000.0,
@@ -195,12 +197,8 @@ def _compute_kv_values(x_input):
             cache_update_index=self_attention_cache_update_index,
         )
 
-        attn_output = ops.reshape(attn_output, (*input_shape, -1))
-
         attn_output = self.o_proj(attn_output)
 
-
-
         if self_attention_cache is not None:
             return attn_output, self_attention_cache
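The `_compute_attention`/`_masked_softmax` pair introduced above reduces to a scaled einsum score map, a masked softmax, and a second einsum against the values. Here is a standalone sketch of that path over toy tensors; the shapes, the hand-built causal mask, and the `1/sqrt(head_dim)` factor standing in for `self._inv_norm_factor` are assumptions made for the example, while the two einsum equations are the ones defined in these patches.

    import math

    import numpy as np
    from keras import layers
    from keras import ops

    batch, seq_len, num_heads, head_dim = 2, 5, 4, 8

    # Layout used by the layer: (batch, seq_len, num_heads, head_dim).
    query = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")
    key = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")
    value = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")

    # "bquh,bkuh->buqk": contract head_dim, keep a (query, key) score map per head.
    scores = ops.einsum("bquh,bkuh->buqk", query, key)
    scores = ops.multiply(scores, 1.0 / math.sqrt(head_dim))

    # Boolean causal mask, broadcast over the head axis exactly like
    # `attention_mask[:, None, :, :]` in `_masked_softmax`.
    causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    causal_mask = np.broadcast_to(causal_mask, (batch, seq_len, seq_len))
    softmax = layers.Softmax(axis=-1, dtype="float32")
    weights = softmax(scores, causal_mask[:, None, :, :])

    # "buqk,bkuh->bquh": weighted sum of values, back to (batch, seq, heads, head_dim).
    attn = ops.einsum("buqk,bkuh->bquh", ops.cast(weights, "float32"), value)
    print(ops.shape(attn))  # (2, 5, 4, 8)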
From 72fabf49f153ff19b5ed8e127924281f4660fd4c Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 14:00:57 +0900
Subject: [PATCH 70/76] new attention computation

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index 5da639e966..8fcc64577d 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -13,6 +13,7 @@
 from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
 from keras_hub.src.models.smollm3.smollm3_utils import rope_init
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+import math
 
 
 class SmolLM3Attention(layers.Layer):
@@ -54,6 +55,9 @@ def __init__(
         self._dot_product_equation = "bquh,bkuh->buqk"
         self._combine_equation = "buqk,bkuh->bquh"
 
+        self.head_dim = hidden_size // self.num_attention_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+
         self.rotary_embedding = RotaryEmbedding(
             max_wavelength=5000000.0,
         )

From 66170cb8fb9d793424bc7ab8b701f1fa2e3acf79 Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 14:02:51 +0900
Subject: [PATCH 71/76] new attention computation

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index 8fcc64577d..bc97324029 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -96,6 +96,12 @@ def __init__(
             else True
         ) # Default to True if index out of bounds
 
+        self._softmax = layers.Softmax(
+            axis=-1,
+            dtype="float32",
+            name="attention_softmax",
+        )
+
     def build(self, input_shape):
         """
         Builds the internal Dense layers.

From ceab147f7d73f9a771c4f060d11db0081c4df5ed Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 14:07:53 +0900
Subject: [PATCH 72/76] new attention computation

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index bc97324029..4ad1a73728 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -86,9 +86,17 @@ def __init__(
             use_bias=self.attention_bias,
             name="v_proj",
         )
-        self.o_proj = layers.Dense(
-            self.hidden_size, use_bias=self.attention_bias, name="o_proj"
-        )
+        #self.o_proj = layers.Dense(
+        #    self.hidden_size, use_bias=self.attention_bias, name="o_proj"
+        #)
+        self.o_proj = layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, self.hidden_size),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self.o_proj.build((None, None, self.num_attention_heads, self.head_dim))
 
         self.use_rope = (
             self.rope_layer_enabled_list[self.layer_idx]

From 0ded2ae22f7ff0869aaa0cd1a94d9853344694f2 Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 14:08:35 +0900
Subject: [PATCH 73/76] new attention computation

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index 4ad1a73728..b30933b775 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -92,8 +92,6 @@ def __init__(
         self.o_proj = layers.EinsumDense(
             equation="bquh,uhm->bqm",
             output_shape=(None, self.hidden_size),
-            kernel_initializer=self.kernel_initializer,
-            dtype=self.dtype_policy,
             name="attention_output",
         )
         self.o_proj.build((None, None, self.num_attention_heads, self.head_dim))
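The `EinsumDense` projection above folds the head and head-dim axes back into the model dimension in one step, which is what lets the explicit `ops.reshape(..., (*input_shape, -1))` disappear from the call path. A small sketch of that equation in isolation, with toy sizes of my own choosing:

    import numpy as np
    from keras import layers
    from keras import ops

    batch, seq_len, num_heads, head_dim, hidden_size = 2, 5, 4, 8, 32

    # "bquh,uhm->bqm": contract (head, head_dim) against a (head, head_dim, hidden)
    # kernel, producing (batch, seq, hidden) directly from the per-head outputs.
    o_proj = layers.EinsumDense(
        equation="bquh,uhm->bqm",
        output_shape=(None, hidden_size),
        name="attention_output",
    )
    o_proj.build((None, None, num_heads, head_dim))

    attn = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")
    out = o_proj(attn)
    print(ops.shape(out))  # (2, 5, 32)

Building the layer against `(None, None, num_heads, head_dim)` up front is also why the next patch can drop the `self.o_proj.build(hidden_states_shape)` call from `build()`.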
From 93bc8e84508d925d0736b18d2f0c53ba2a51d13c Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 14:09:37 +0900
Subject: [PATCH 74/76] new attention computation

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index b30933b775..f130a17884 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -122,7 +122,6 @@ def build(self, input_shape):
         self.q_proj.build(hidden_states_shape)
         self.k_proj.build(hidden_states_shape)
         self.v_proj.build(hidden_states_shape)
-        self.o_proj.build(hidden_states_shape)
         super().build(input_shape)
 
     def call(

From af85773bd2751da013abe69cd2736cbcb02f7ae2 Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 14:15:26 +0900
Subject: [PATCH 75/76] slight cleanup

---
 .../src/models/smollm3/smollm3_backbone.py    | 18 +--------
 .../src/models/smollm3/smollm3_causal_lm.py   |  4 ---
 .../src/models/smollm3/smollm3_layers.py      |  6 ----
 keras_hub/src/models/smollm3/smollm3_utils.py | 33 -------------------
 4 files changed, 1 insertion(+), 60 deletions(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py
index 7b064bfbff..34de272091 100644
--- a/keras_hub/src/models/smollm3/smollm3_backbone.py
+++ b/keras_hub/src/models/smollm3/smollm3_backbone.py
@@ -99,15 +99,7 @@ def __init__(
             epsilon=layer_norm_epsilon,
             name="sequence_output_layernorm",
         )
-
-        self.rotary_embedding = SmolLM3RotaryEmbedding(
-            hidden_size=hidden_dim,
-            num_attention_heads=num_attention_heads,
-            max_position_embeddings=max_position_embeddings,
-            rope_theta=rope_theta,
-            partial_rotary_factor=partial_rotary_factor,
-        )
-
+
         # === Functional Model ===
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
@@ -117,19 +109,11 @@ def __init__(
             shape=(None,), dtype="int32", name="padding_mask"
         )
 
-        cache_update_index = kwargs.get('self_attention_cache_index')
-
-        start_index = (
-            cache_update_index if cache_update_index is not None else 0
-        )
-
         x = self.token_embedding(token_id_input)
-        position_embeddings = self.rotary_embedding(x, start_index=start_index)
 
         for decoder_layer in self.transformer_layers:
             x = decoder_layer(
                 x,
-                position_embeddings=position_embeddings,
                 decoder_padding_mask=padding_mask_input,
                 **kwargs,
             )
diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py
index 57bd66f3ac..4fbeec477f 100644
--- a/keras_hub/src/models/smollm3/smollm3_causal_lm.py
+++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py
@@ -71,18 +71,14 @@ def call_with_cache(
 
         # Each decoder layer has a cache; we update them separately.
         updated_cache = []
-        position_embeddings = self.backbone.rotary_embedding(x, start_index=cache_update_index)
 
         for i in range(self.backbone.num_layers):
-            print(f"Decoder layer {i}")
             current_cache = cache[:, i, ...]
             x, next_cache = self.backbone.transformer_layers[i](
                 x,
-                position_embeddings=position_embeddings,
                 self_attention_cache=current_cache,
                 self_attention_cache_update_index=cache_update_index,
             )
-            print(next_cache.shape)
             updated_cache.append(next_cache)
         cache = ops.stack(updated_cache, axis=1)
         hidden_states = x = self.backbone.norm(x)
diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index f130a17884..37f6507dd3 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -86,9 +86,6 @@ def __init__(
             use_bias=self.attention_bias,
             name="v_proj",
         )
-        #self.o_proj = layers.Dense(
-        #    self.hidden_size, use_bias=self.attention_bias, name="o_proj"
-        #)
         self.o_proj = layers.EinsumDense(
             equation="bquh,uhm->bqm",
             output_shape=(None, self.hidden_size),
@@ -127,7 +124,6 @@ def build(self, input_shape):
     def call(
         self,
         hidden_states,
-        position_embeddings,
         training=False,
         attention_mask=None,
         **kwargs,
@@ -508,7 +504,6 @@ def build(self, input_shape):
     def call(
         self,
         hidden_states,
-        position_embeddings=None,
         training=False,
         decoder_padding_mask=None,
         decoder_attention_mask=None,
@@ -541,7 +536,6 @@ def call(
         # Self Attention
         x = self.self_attn(
             hidden_states=hidden_states,
-            position_embeddings=position_embeddings,
             training=training,
             attention_mask=self_attention_mask,
             **kwargs,
diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py
index 07b895883a..8fb057f363 100644
--- a/keras_hub/src/models/smollm3/smollm3_utils.py
+++ b/keras_hub/src/models/smollm3/smollm3_utils.py
@@ -36,39 +36,6 @@ def repeat_kv(hidden_states, n_rep):
     )
 
 
-def eager_attention_forward(
-    query,
-    key,
-    value,
-    scaling,
-    attention_mask=None,
-    dropout=0.0,
-    training=False,
-):
-    softmax_op = layers.Softmax(
-        axis=-1,
-        dtype="float32",
-        name="attention_softmax",
-    )
-
-    attn_weights = (
-        ops.matmul(query, ops.transpose(key, axes=(0, 1, 3, 2)))
-        * scaling
-    )
-
-    if attention_mask is not None:
-        attn_weights = softmax_op(attn_weights, attention_mask[:, None, :, :])
-    else:
-        attn_weights = softmax_op(attn_weights)
-
-    if training:
-        attn_weights = random.dropout(attn_weights, rate=dropout)
-    attn_output = ops.matmul(attn_weights, value)
-    attn_output = ops.transpose(attn_output, axes=(0, 2, 1, 3))
-
-    return attn_output
-
-
 def rope_init(rope_theta: float, partial_rotary_factor: float, head_dim: int):
     base = rope_theta
     dim = int(head_dim * partial_rotary_factor)

From f66846b6865f069b13c3286b00b514b7821e0763 Mon Sep 17 00:00:00 2001
From: David Landup
Date: Sun, 17 Aug 2025 14:16:30 +0900
Subject: [PATCH 76/76] slight cleanup

---
 keras_hub/src/models/smollm3/smollm3_layers.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py
index 37f6507dd3..811aaa9dac 100644
--- a/keras_hub/src/models/smollm3/smollm3_layers.py
+++ b/keras_hub/src/models/smollm3/smollm3_layers.py
@@ -9,8 +9,6 @@
 from keras_hub.src.layers.modeling.transformer_layer_utils import (
     merge_padding_and_attention_mask,
 )
-from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb, apply_rotary_pos_single
-from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
 from keras_hub.src.models.smollm3.smollm3_utils import rope_init
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 import math
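One last illustration, of the grouped-query step that the earlier 'pre'/'post' prints bracket and that remains in the attention layer after these cleanups: key and value are projected with `num_key_value_heads` heads and then tiled up to the query head count with `ops.repeat` on the head axis, so the einsum score and combine equations see one key/value head per query head. The sizes below are made up for the sketch:

    import numpy as np
    from keras import ops

    batch, seq, num_heads, num_kv_heads, head_dim = 2, 6, 8, 2, 16
    groups = num_heads // num_kv_heads  # each key/value head serves 4 query heads

    key = np.random.rand(batch, seq, num_kv_heads, head_dim).astype("float32")
    value = np.random.rand(batch, seq, num_kv_heads, head_dim).astype("float32")

    # 'pre': (2, 6, 2, 16) -- only num_key_value_heads distinct heads.
    print(key.shape, value.shape)

    key = ops.repeat(key, repeats=groups, axis=2)
    value = ops.repeat(value, repeats=groups, axis=2)

    # 'post': (2, 6, 8, 16) -- tiled up to num_attention_heads.
    print(ops.shape(key), ops.shape(value))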