
Commit 5b6e62a

feat: update backbone + add tokenizer
1 parent 6bed56a commit 5b6e62a

5 files changed: +303 −33 lines changed

keras_hub/src/models/modernbert/modernbert_backbone.py

Lines changed: 17 additions & 33 deletions
@@ -5,20 +5,15 @@
     ReversibleEmbedding,
 )
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
-from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+from keras_hub.src.models.modernbert.modernbert_layers import (
+    ModernBERTEncoderLayer,
+)
 from keras_hub.src.utils.keras_utils import gelu_approximate
 
 
 @keras_hub_export("keras_hub.models.ModernBertBackbone")
 class ModernBertBackbone(Backbone):
-    """A ModernBERT encoder network.
-
-    This class implements the ModernBERT backbone, using rotary embeddings,
-    RMS normalization, and a stack of TransformerEncoder layers.
-    """
-
     def __init__(
         self,
         vocabulary_size,
@@ -45,37 +40,33 @@ def __init__(
         )
         self.position_embedding = RotaryEmbedding(
             max_wavelength=rotary_max_wavelength,
-            sequence_axis=1,
-            feature_axis=-1,
             dtype=dtype,
             name="rotary_embedding",
         )
-        self.embeddings_layer_norm = RMSNormalization(
-            dtype=dtype,
+        self.embeddings_layer_norm = keras.layers.LayerNormalization(
             epsilon=layer_norm_epsilon,
-        )
-        self.embeddings_dropout = keras.layers.Dropout(
-            dropout, dtype=dtype, name="embeddings_dropout"
+            dtype=dtype,
+            rms_scaling=True,
+            name="embeddings_layer_norm",
         )
         self.transformer_layers = []
         for i in range(num_layers):
-            layer = TransformerEncoder(
+            layer = ModernBERTEncoderLayer(
+                hidden_size=hidden_dim,
+                intermediate_size=intermediate_dim,
                 num_heads=num_heads,
-                intermediate_dim=intermediate_dim,
                 activation=gelu_approximate,
-                dropout=dropout,
                 layer_norm_epsilon=layer_norm_epsilon,
-                kernel_initializer=keras.initializers.TruncatedNormal(
-                    stddev=0.02
-                ),
+                rotary_embedding=self.position_embedding,
                 dtype=dtype,
                 name=f"transformer_layer_{i}",
             )
             self.transformer_layers.append(layer)
-        self.final_norm = RMSNormalization(
-            dtype=dtype,
+        self.final_norm = keras.layers.LayerNormalization(
             epsilon=layer_norm_epsilon,
-            name="final_normalization",
+            rms_scaling=True,
+            dtype=dtype,
+            name="final_layernorm",
         )
 
         # === Functional Model ===
@@ -85,20 +76,13 @@ def __init__(
         padding_mask_input = keras.Input(
             shape=(None,), dtype="int32", name="padding_mask"
         )
-
-        # Embed tokens and apply rotary position embedding
         x = self.token_embedding(token_id_input)
-        x = self.position_embedding(x)
         x = self.embeddings_layer_norm(x)
-        x = self.embeddings_dropout(x)
-
-        # Transformer layers
         for transformer_layer in self.transformer_layers:
-            x = transformer_layer(x, padding_mask=padding_mask_input)
-
-        # Final normalization
+            x = transformer_layer(x)
         sequence_output = self.final_norm(x)
 
+        # Instantiate using Functional API Model constructor
         super().__init__(
             inputs={
                 "token_ids": token_id_input,
keras_hub/src/models/modernbert/modernbert_layers.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors
+from keras_hub.src.models.flux.flux_maths import scaled_dot_product_attention
+
+
+class MLP(keras.layers.Layer):
+    def __init__(
+        self,
+        hidden_size,
+        intermediate_size,
+        activation="gelu",
+        dtype=None,
+        **kwargs,
+    ):
+        super(MLP, self).__init__(**kwargs)
+        self.Wi = layers.Dense(
+            intermediate_size * 2,
+            use_bias=False,
+            dtype=dtype,
+        )
+        self.act = keras.activations.get(activation)
+        self.Wo = layers.Dense(
+            hidden_size,
+            use_bias=False,
+            dtype=dtype,
+        )
+
+    def call(self, x):
+        input, gate = ops.split(self.Wi(x), 2, axis=-1)
+        return self.Wo(self.act(input) * gate)
+
+
+class ModernBERTAttention(keras.Model):
+    def __init__(
+        self, hidden_size, num_heads, rotary_embedding, dtype=None, **kwargs
+    ):
+        super(ModernBERTAttention, self).__init__(**kwargs)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.rotary_embedding = rotary_embedding
+        self.Wqkv = layers.Dense(hidden_size * 3, use_bias=False, dtype=dtype)
+        self.Wo = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
+
+    def build(self, input_shape):
+        self.Wqkv.build(input_shape)
+        self.Wo.build((None, input_shape[1], input_shape[-1]))
+
+    def call(self, x):
+        qkv = self.Wqkv(x)
+        q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads)
+
+        # Apply rotary embeddings
+        q = self.rotary_embedding(q)
+        k = self.rotary_embedding(k)
+
+        # Apply scaled dot product attention
+        x = scaled_dot_product_attention(q, k, v)
+
+        # Reshape and apply final dense layer
+        x = ops.transpose(x, (0, 2, 1, 3))
+        b, s, h, d = ops.shape(x)
+        x = ops.reshape(x, (b, s, h * d))
+        x = self.Wo(x)
+        return x
+
+
+class ModernBERTEncoderLayer(keras.Model):
+    def __init__(
+        self,
+        hidden_size,
+        intermediate_size,
+        num_heads,
+        activation="gelu",
+        layer_norm_epsilon=1e-05,
+        rotary_embedding=None,
+        dtype=None,
+        **kwargs,
+    ):
+        super(ModernBERTEncoderLayer, self).__init__(**kwargs)
+        self.attn = ModernBERTAttention(
+            hidden_size, num_heads, rotary_embedding, dtype=dtype
+        )
+        self.mlp_norm = layers.LayerNormalization(
+            epsilon=layer_norm_epsilon, dtype=dtype
+        )
+        self.mlp = MLP(hidden_size, intermediate_size, activation, dtype=dtype)
+
+    def call(self, x):
+        x = self.attn(x)
+        x = self.mlp_norm(x)
+        x = self.mlp(x)
+        return x
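
The new MLP is a gated feed-forward block: Wi projects to twice the intermediate size, the result splits into an input half and a gate half, and the activated input is scaled elementwise by the gate before Wo projects back to the hidden size. A standalone sketch of that computation using keras.ops, with hypothetical shapes:

import numpy as np
from keras import activations, ops

# (batch, seq, hidden) = (1, 4, 8); intermediate size 16 (hypothetical).
x = np.random.rand(1, 4, 8).astype("float32")
wi = np.random.rand(8, 2 * 16).astype("float32")  # Wi kernel: hidden -> 2 * intermediate
wo = np.random.rand(16, 8).astype("float32")  # Wo kernel: intermediate -> hidden

inp, gate = ops.split(ops.matmul(x, wi), 2, axis=-1)  # two (1, 4, 16) halves
out = ops.matmul(activations.gelu(inp) * gate, wo)  # gated GELU, project back
print(ops.shape(out))  # (1, 4, 8)

As committed, ModernBERTEncoderLayer.call chains attention, mlp_norm, and the MLP strictly sequentially, with no residual connection around either sub-block.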
keras_hub/src/models/modernbert/modernbert_tokenizer.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.modernbert.modernbert_backbone import (
+    ModernBertBackbone,
+)
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.ModernBertTokenizer",
+        "keras_hub.models.ModernBertTokenizer",
+    ]
+)
+class ModernBertTokenizer(BytePairTokenizer):
+    backbone_cls = ModernBertBackbone
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        **kwargs,
+    ):
+        self._add_special_token("[CLS]", "cls_token")
+        self._add_special_token("[SEP]", "sep_token")
+        self._add_special_token("[PAD]", "pad_token")
+        self._add_special_token("[UNK]", "unk_token")
+        self._add_special_token("[MASK]", "mask_token")
+        # Also add `tokenizer.start_token` and `tokenizer.end_token` for
+        # compatibility with other tokenizers.
+        self._add_special_token("[CLS]", "start_token")
+        self._add_special_token("[SEP]", "end_token")
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            **kwargs,
+        )
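
For reference, a hedged usage sketch of the new tokenizer, reusing the toy BPE vocabulary and merges from the unit test below (real checkpoints would ship full vocabulary and merge files instead):

from keras_hub.src.models.modernbert.modernbert_tokenizer import (
    ModernBertTokenizer,
)

# Toy BPE assets copied from the tokenizer test below.
vocab = ["[CLS]", "[PAD]", "[SEP]", "air", "Ġair", "plane", "Ġat"]
vocab += ["port", "[MASK]", "[UNK]"]
vocab = {token: i for i, token in enumerate(vocab)}
merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
merges += ["Ġai r", "Ġa i", "pla ne"]

tokenizer = ModernBertTokenizer(vocabulary=vocab, merges=merges)
print(tokenizer(" airplane airport"))  # -> [4, 5, 4, 7], per the test below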
keras_hub/src/models/modernbert/modernbert_tokenizer_test.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from keras_hub.src.models.modernbert.modernbert_tokenizer import (
+    ModernBertTokenizer,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class ModernBertTokenizerTest(TestCase):
+    def setUp(self):
+        self.vocab = ["[CLS]", "[PAD]", "[SEP]", "air", "Ġair", "plane", "Ġat"]
+        self.vocab += ["port", "[MASK]", "[UNK]"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+        self.merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+        self.input_data = [
+            "[CLS] airplane at airport[SEP][PAD]",
+            " airplane airport",
+        ]
+
+    def test_tokenizer_basics(self):
+        self.run_preprocessing_layer_test(
+            cls=ModernBertTokenizer,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]],
+            expected_detokenize_output=[
+                "[CLS] airplane at airport[SEP][PAD]",
+                " airplane airport",
+            ],
+        )
+
+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            ModernBertTokenizer(vocabulary=["a", "b", "c"], merges=[])
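
The second test relies on the special tokens registered via _add_special_token being validated against the vocabulary when the superclass initializes, so a vocabulary missing [CLS], [SEP], and friends raises a ValueError. To run just this file locally (assuming pytest, which the repo's test suite uses):

pytest keras_hub/src/models/modernbert/modernbert_tokenizer_test.py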
tools/checkpoint_conversion/convert_modernbert_checkpoints.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
1+
"""Convert ModernBERT checkpoints.
2+
3+
python tools/checkpoint_conversion/convert_modernbert_checkpoints.py \
4+
--preset modernbert_base
5+
python tools/checkpoint_conversion/convert_modernbert_checkpoints.py \
6+
--preset modernbert_large
7+
"""
8+
9+
import json
10+
import os
11+
12+
import numpy as np
13+
import requests
14+
import transformers
15+
from absl import app
16+
from absl import flags
17+
18+
from keras_hub.src.models.modernbert.modernbert_backbone import (
19+
ModernBertBackbone,
20+
)
21+
22+
PRESET_MAP = {
23+
"modernbert_base": "answerdotai/ModernBERT-base",
24+
"modernbert_large": "answerdotai/ModernBERT-large",
25+
}
26+
27+
EXTRACT_DIR = "./{}"
28+
29+
FLAGS = flags.FLAGS
30+
flags.DEFINE_string(
31+
"preset",
32+
None,
33+
f"Must be one of {','.join(PRESET_MAP.keys())}",
34+
)
35+
36+
37+
def download_files(hf_model_name):
38+
extract_dir = EXTRACT_DIR.format(FLAGS.preset)
39+
if not os.path.exists(extract_dir):
40+
os.makedirs(extract_dir)
41+
42+
# Config.
43+
config_path = os.path.join(extract_dir, "config.json")
44+
response = requests.get(
45+
f"https://huggingface.co/{hf_model_name}/raw/main/config.json"
46+
)
47+
open(config_path, "wb").write(response.content)
48+
49+
50+
def convert_model(hf_model):
51+
extract_dir = EXTRACT_DIR.format(FLAGS.preset)
52+
config_path = os.path.join(extract_dir, "config.json")
53+
54+
# Build config.
55+
cfg = {}
56+
with open(config_path, "r") as pt_cfg_handler:
57+
pt_cfg = json.load(pt_cfg_handler)
58+
cfg["vocabulary_size"] = pt_cfg["vocab_size"]
59+
cfg["num_layers"] = pt_cfg["num_hidden_layers"]
60+
cfg["num_heads"] = pt_cfg["num_attention_heads"]
61+
cfg["hidden_dim"] = pt_cfg["hidden_size"]
62+
cfg["intermediate_dim"] = pt_cfg["intermediate_size"]
63+
cfg["dropout"] = pt_cfg["embedding_dropout"]
64+
cfg["max_sequence_length"] = pt_cfg["max_position_embeddings"]
65+
66+
return ModernBertBackbone(**cfg)
67+
68+
69+
def convert_weights(keras_model, hf_model):
70+
# Get `state_dict` from `hf_model`.
71+
state_dict = hf_model.state_dict()
72+
73+
keras_model.get_layer("token_embedding").set_weights(
74+
[np.asarray(state_dict["embeddings.tok_embeddings.weight"])]
75+
)
76+
77+
keras_model.get_layer("embeddings_layer_norm").set_weights(
78+
[np.asarray(state_dict["embeddings.norm.weight"])]
79+
)
80+
81+
for i in range(keras_model.num_layers):
82+
keras_model.transformer_layers[i].attn.Wqkv.kernel.assign(
83+
state_dict[f"layers.{i}.attn.Wqkv.weight"].T
84+
)
85+
keras_model.transformer_layers[i].attn.Wo.kernel.assign(
86+
state_dict[f"layers.{i}.attn.Wo.weight"]
87+
)
88+
keras_model.transformer_layers[i].mlp_norm.gamma.assign(
89+
state_dict[f"layers.{i}.mlp_norm.weight"]
90+
)
91+
keras_model.transformer_layers[i].mlp.Wi.kernel.assign(
92+
state_dict[f"layers.{i}.mlp.Wi.weight"].T
93+
)
94+
keras_model.transformer_layers[i].mlp.Wo.kernel.assign(
95+
state_dict[f"layers.{i}.mlp.Wo.weight"].T
96+
)
97+
98+
keras_model.get_layer("final_layernorm").set_weights(
99+
[np.asarray(state_dict["final_norm.weight"])]
100+
)
101+
102+
103+
def main(_):
104+
hf_model_name = PRESET_MAP[FLAGS.preset]
105+
download_files(hf_model_name)
106+
107+
hf_model = transformers.AutoModel.from_pretrained(hf_model_name)
108+
hf_model.eval()
109+
110+
print(f"🏃 Coverting {FLAGS.preset}")
111+
keras_model = convert_model(hf_model)
112+
print("✅ KerasHub model loaded.")
113+
114+
convert_weights(keras_model, hf_model)
115+
print("✅ Weights converted.")
116+
117+
118+
if __name__ == "__main__":
119+
flags.mark_flag_as_required("preset")
120+
app.run(main)
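
One detail the conversion hinges on: torch.nn.Linear stores its weight as (out_features, in_features), while a Keras Dense kernel is (input_dim, units), hence the .T on the Wqkv, Wi, and mlp.Wo assignments (the attention Wo assignment above is the one that skips the transpose). A minimal sketch of why the transpose gives matching outputs:

import numpy as np
import torch

linear = torch.nn.Linear(4, 3, bias=False)  # weight shape: (3, 4)
x = np.random.rand(2, 4).astype("float32")

torch_out = linear(torch.from_numpy(x)).detach().numpy()
keras_kernel = linear.weight.detach().numpy().T  # (4, 3), Dense-style kernel
np.testing.assert_allclose(torch_out, x @ keras_kernel, rtol=1e-5)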
