Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,12 @@
from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import (
MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor,
)
from keras_hub.src.models.modernbert.modernbert_backbone import (
ModernBertBackbone as ModernBertBackbone,
)
from keras_hub.src.models.modernbert.modernbert_tokenizer import (
ModernBertTokenizer as ModernBertTokenizer,
)
from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
MoonshineAudioToText as MoonshineAudioToText,
)
Expand Down
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
from keras_hub.src.models.mixtral.mixtral_tokenizer import (
MixtralTokenizer as MixtralTokenizer,
)
from keras_hub.src.models.modernbert.modernbert_tokenizer import (
ModernBertTokenizer as ModernBertTokenizer,
)
from keras_hub.src.models.moonshine.moonshine_tokenizer import (
MoonshineTokenizer as MoonshineTokenizer,
)
Expand Down
Empty file.
155 changes: 155 additions & 0 deletions keras_hub/src/models/modernbert/modernbert_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import keras
from keras import layers
from keras import ops
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
ReversibleEmbedding,
)
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.backbone import Backbone

from keras_hub.src.models.modernbert.modernbert_layers import (
ModernBertEncoderLayer,
)

from keras_hub.src.utils.keras_utils import gelu_approximate

@keras_hub_export("keras_hub.models.ModernBertBackbone")
class ModernBertBackbone(Backbone):
    """ModernBERT backbone model.

    ModernBERT features Rotary Positional Embeddings (RoPE), GeGLU
    activations, RMSNorm, and Alternating Attention (interleaving local and
    global attention layers).

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        hidden_dim: int. The size of the transformer hidden state.
        intermediate_dim: int. The output dimension of the GeGLU MLP.
        num_layers: int. The number of transformer layers.
        num_heads: int. The number of attention heads.
        local_attention_window: int. Window size for local attention layers.
            Defaults to `128`.
        global_attn_every_n_layers: int. Every n-th layer uses global (full)
            attention; all other layers use local sliding-window attention.
            Defaults to `3`.
        dropout: float. Dropout probability. Defaults to `0.0`.
        rotary_max_wavelength: int. Max wavelength for RoPE. Defaults to
            `160000`.
        dtype: string or `keras.DTypePolicy`. The dtype to use for the
            model's computations and weights.
        layer_norm_epsilon: float. Epsilon for RMSNorm. Defaults to `1e-5`.
    """

    def __init__(
        self,
        vocabulary_size,
        hidden_dim,
        intermediate_dim,
        num_layers,
        num_heads,
        local_attention_window=128,
        global_attn_every_n_layers=3,
        dropout=0.0,
        rotary_max_wavelength=160000,
        dtype=None,
        layer_norm_epsilon=1e-5,
        **kwargs,
    ):
        # === Inputs ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=keras.initializers.TruncatedNormal(
                stddev=0.02
            ),
            dtype=dtype,
            name="token_embedding",
        )
        # A single rotary embedding instance is shared by all encoder layers.
        self.position_embedding = RotaryEmbedding(
            max_wavelength=rotary_max_wavelength,
            dtype=dtype,
            name="rotary_embedding",
        )
        # `rms_scaling=True` turns `LayerNormalization` into an RMSNorm.
        self.embeddings_layer_norm = keras.layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            rms_scaling=True,
            name="embeddings_layer_norm",
        )

        # === Functional model ===
        x = self.token_embedding(token_id_input)
        x = self.embeddings_layer_norm(x)
        # Identity activation used purely to cast activations to `dtype`.
        x = keras.layers.Activation(
            None, dtype=dtype, name="embeddings_cast"
        )(x)
        for i in range(num_layers):
            # Every `global_attn_every_n_layers`-th layer attends globally;
            # all other layers use a local sliding window.
            is_global = (i + 1) % global_attn_every_n_layers == 0
            current_window = None if is_global else local_attention_window
            x = ModernBertEncoderLayer(
                hidden_dim=hidden_dim,
                intermediate_dim=intermediate_dim,
                num_heads=num_heads,
                rotary_embedding=self.position_embedding,
                local_attention_window=current_window,
                dropout=dropout,
                layer_norm_epsilon=layer_norm_epsilon,
                dtype=dtype,
                name=f"transformer_layer_{i}",
            )(x, padding_mask=padding_mask_input)
            # Force dtype consistency after each encoder layer.
            x = keras.layers.Activation(
                None, dtype=dtype, name=f"transformer_cast_{i}"
            )(x)
        sequence_output = keras.layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            rms_scaling=True,
            dtype=dtype,
            name="final_norm",
        )(x)
        sequence_output = keras.layers.Activation(
            None, dtype=dtype, name="sequence_output_cast"
        )(sequence_output)

        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.local_attention_window = local_attention_window
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.dropout = dropout
        self.rotary_max_wavelength = rotary_max_wavelength
        self.layer_norm_epsilon = layer_norm_epsilon

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "local_attention_window": self.local_attention_window,
                "global_attn_every_n_layers": self.global_attn_every_n_layers,
                "dropout": self.dropout,
                "rotary_max_wavelength": self.rotary_max_wavelength,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # The default functional-model `from_config()` would return a plain
        # `keras.Model`; rebuild through `__init__` to get this subclass.
        return cls(**config)
53 changes: 53 additions & 0 deletions keras_hub/src/models/modernbert/modernbert_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import keras
import pytest
from keras import ops

from keras_hub.src.models.modernbert.modernbert_backbone import (
ModernBertBackbone
)
from keras_hub.src.tests.test_case import TestCase

class ModernBertBackboneTest(TestCase):
    """Tests for the ModernBERT backbone."""

    def setUp(self):
        """Set up a small model configuration and dummy inputs for testing."""
        self.init_kwargs = {
            "vocabulary_size": 10,
            "hidden_dim": 8,
            "intermediate_dim": 64,
            "num_layers": 2,
            "num_heads": 4,
            "local_attention_window": 128,
            "global_attn_every_n_layers": 2,
            "dropout": 0.0,
        }
        self.input_data = {
            "token_ids": ops.ones((2, 5), dtype="int32"),
            "padding_mask": ops.ones((2, 5), dtype="int32"),
        }

    def test_backbone_basics(self):
        """Test backbone forward pass and standard KerasHub lifecycle.

        This validates:
        1. Forward pass with the given input data.
        2. Config serialization (get_config/from_config).
        3. Model saving and loading via the `.keras` format.
        4. Backend-agnostic execution.
        """
        self.run_backbone_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 8),
        )

    @pytest.mark.large
    def test_saved_model(self):
        """Test that the model can be saved and loaded in the .keras format."""
        self.run_model_saving_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
Loading