Commit 26511b2

add causal attn mask, a few fixes
1 parent e126938 commit 26511b2

File tree

9 files changed: +677 −58 lines

keras_hub/api/models/__init__.py

Lines changed: 24 additions & 0 deletions
@@ -576,6 +576,30 @@
 from keras_hub.src.models.siglip.siglip_vision_encoder import (
     SigLIPVisionEncoder as SigLIPVisionEncoder,
 )
+from keras_hub.src.models.smollm3.smollm3_backbone import (
+    SmolLM3Backbone as SmolLM3Backbone,
+)
+from keras_hub.src.models.smollm3.smollm3_backbone import (
+    SmolLM3Backbone as SmolLMBackbone,
+)
+from keras_hub.src.models.smollm3.smollm3_causal_lm import (
+    SmolLM3CausalLM as SmolLM3CausalLM,
+)
+from keras_hub.src.models.smollm3.smollm3_causal_lm import (
+    SmolLM3CausalLM as SmolLMCausalLM,
+)
+from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import (
+    SmolLM3CausalLMPreprocessor as SmolLM3CausalLMPreprocessor,
+)
+from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import (
+    SmolLM3CausalLMPreprocessor as SmolLMCausalLMPreprocessor,
+)
+from keras_hub.src.models.smollm3.smollm3_tokenizer import (
+    SmolLM3Tokenizer as SmolLM3Tokenizer,
+)
+from keras_hub.src.models.smollm3.smollm3_tokenizer import (
+    SmolLM3Tokenizer as SmolLMTokenizer,
+)
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
     StableDiffusion3Backbone as StableDiffusion3Backbone,
 )
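
Each class is exported twice: once under its full "SmolLM3*" name and once under a shorter "SmolLM*" alias. A minimal sketch of what that means for callers, assuming the generated keras_hub/api/models package is surfaced as keras_hub.models (the usual api layout):

from keras_hub.models import SmolLM3Backbone, SmolLMBackbone
from keras_hub.models import SmolLM3CausalLM, SmolLMCausalLM

# Each alias pair points at the same class object, so either spelling works.
assert SmolLM3Backbone is SmolLMBackbone
assert SmolLM3CausalLM is SmolLMCausalLM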

keras_hub/api/tokenizers/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -86,6 +86,12 @@
 from keras_hub.src.models.siglip.siglip_tokenizer import (
     SigLIPTokenizer as SigLIPTokenizer,
 )
+from keras_hub.src.models.smollm3.smollm3_tokenizer import (
+    SmolLM3Tokenizer as SmolLM3Tokenizer,
+)
+from keras_hub.src.models.smollm3.smollm3_tokenizer import (
+    SmolLM3Tokenizer as SmolLMTokenizer,
+)
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
 from keras_hub.src.models.whisper.whisper_tokenizer import (
     WhisperTokenizer as WhisperTokenizer,
 )
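
The tokenizer gets the same dual export on the tokenizers API surface. A small sketch, assuming keras_hub/api/tokenizers is exposed as keras_hub.tokenizers:

from keras_hub.tokenizers import SmolLM3Tokenizer, SmolLMTokenizer

# Both names resolve to the class defined in
# keras_hub/src/models/smollm3/smollm3_tokenizer.py.
assert SmolLM3Tokenizer is SmolLMTokenizer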

keras_hub/src/models/smollm3/smollm3_backbone.py

Lines changed: 39 additions & 5 deletions
@@ -1,6 +1,9 @@
 import keras
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding
@@ -68,12 +71,12 @@ def __init__(
         **kwargs,
     ):
         # === Layers ===
-        self.token_embedding = keras.layers.Embedding(
+        self.token_embedding = ReversibleEmbedding(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             name="token_embedding",
         )
-        self.decoder_layers = []
+        self.transformer_layers = []
 
         for i in range(num_layers):
             layer = SmolLM3DecoderLayer(
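
Swapping keras.layers.Embedding for ReversibleEmbedding matters for the causal LM: the same weight matrix embeds token ids on the way in and, run in reverse, projects hidden states back to vocabulary logits. A standalone sketch of that behavior, assuming the standard keras_hub ReversibleEmbedding call signature:

import numpy as np
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)

# Forward: token ids -> hidden vectors; reverse: hidden vectors -> logits
# over the vocabulary, using the tied embedding weights.
embedding = ReversibleEmbedding(input_dim=128, output_dim=16)
token_ids = np.array([[1, 2, 3]])
hidden = embedding(token_ids)             # shape (1, 3, 16)
logits = embedding(hidden, reverse=True)  # shape (1, 3, 128)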
@@ -87,10 +90,10 @@ def __init__(
                 layer_idx=i,
                 intermediate_size=intermediate_dim,
                 mlp_bias=mlp_bias,
-                rms_norm_epsilon=layer_norm_epsilon,
+                layer_norm_epsilon=layer_norm_epsilon,
                 name=f"transformer_layer_{i}",
             )
-            self.decoder_layers.append(layer)
+            self.transformer_layers.append(layer)
 
         self.norm = keras.layers.RMSNormalization(
             epsilon=layer_norm_epsilon,
@@ -112,16 +115,20 @@ def __init__(
         position_id_input = keras.Input(
             shape=(None,), dtype="int32", name="position_ids"
         )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
 
         hidden_states = self.token_embedding(token_id_input)
         position_embeddings = self.rotary_embedding(
             hidden_states, position_id_input
         )
 
-        for decoder_layer in self.decoder_layers[:num_layers]:
+        for decoder_layer in self.transformer_layers[:num_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
                 position_embeddings=position_embeddings,
+                decoder_padding_mask=padding_mask_input,
                 **kwargs,
             )
 
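This hunk only threads the new padding_mask input through to every decoder layer; the causal attention mask named in the commit title is built inside the attention code, which is not part of this excerpt. As a generic illustration of the usual pattern (not the code in this commit), a causal mask can be merged with a padding mask roughly like this, using keras.ops:

from keras import ops

def merged_attention_mask(padding_mask):
    # padding_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    # Returns a boolean (batch, seq_len, seq_len) mask where query i may
    # attend to key j only if j <= i (causal) and token j is not padding.
    seq_len = ops.shape(padding_mask)[1]
    row = ops.expand_dims(ops.arange(seq_len), axis=1)          # (seq_len, 1)
    col = ops.expand_dims(ops.arange(seq_len), axis=0)          # (1, seq_len)
    causal = ops.expand_dims(ops.less_equal(col, row), axis=0)  # (1, seq_len, seq_len)
    keep = ops.expand_dims(ops.cast(padding_mask, "bool"), axis=1)  # (batch, 1, seq_len)
    return ops.logical_and(causal, keep)

keras_hub also ships helpers along these lines (compute_causal_mask and merge_padding_and_attention_mask in keras_hub.src.layers.modeling.transformer_layer_utils), which is the more likely home for the actual mask logic.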
@@ -130,21 +137,48 @@ def __init__(
             inputs={
                 "token_ids": token_id_input,
                 "position_ids": position_id_input,
+                "padding_mask": padding_mask_input,
             },
             outputs=sequence_output,
             **kwargs,
         )
 
         # === Config ===
         self.vocabulary_size = vocabulary_size
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
         self.num_layers = num_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.rope_layer_enabled_list = rope_layer_enabled_list
+        self.layer_types = layer_types
+        self.mlp_bias = mlp_bias
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.partial_rotary_factor = partial_rotary_factor
 
     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "vocabulary_size": self.vocabulary_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
                 "num_layers": self.num_layers,
+                "num_attention_heads": self.num_attention_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "attention_bias": self.attention_bias,
+                "attention_dropout": self.attention_dropout,
+                "rope_layer_enabled_list": self.rope_layer_enabled_list,
+                "layer_types": self.layer_types,
+                "mlp_bias": self.mlp_bias,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "max_position_embeddings": self.max_position_embeddings,
+                "rope_theta": self.rope_theta,
+                "partial_rotary_factor": self.partial_rotary_factor,
             }
         )
         return config
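
With the padding_mask input and the fully serialized config, the backbone can be built, called, and reconstructed from its config. A hedged sketch follows; the constructor argument names mirror the config keys above, but defaults, required-vs-optional status, and the accepted layer_types values are assumptions, not something this diff shows:

import numpy as np
from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone

backbone = SmolLM3Backbone(
    vocabulary_size=1000,
    hidden_dim=64,
    intermediate_dim=128,
    num_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    attention_bias=False,
    attention_dropout=0.0,
    rope_layer_enabled_list=[True, True],    # one flag per layer
    layer_types=["attention", "attention"],  # hypothetical values
    mlp_bias=False,
    layer_norm_epsilon=1e-6,
    max_position_embeddings=128,
    rope_theta=10000.0,
    partial_rotary_factor=1.0,
)

# The functional model now expects all three named inputs.
batch, seq_len = 2, 8
outputs = backbone(
    {
        "token_ids": np.random.randint(0, 1000, size=(batch, seq_len)).astype("int32"),
        "position_ids": np.tile(np.arange(seq_len, dtype="int32"), (batch, 1)),
        "padding_mask": np.ones((batch, seq_len), dtype="int32"),
    }
)
print(outputs.shape)  # (2, 8, 64): one hidden vector per token

# get_config() now captures every constructor argument, so the model
# round-trips through serialization.
restored = SmolLM3Backbone.from_config(backbone.get_config())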
