-
Notifications
You must be signed in to change notification settings - Fork 328
ModernBERT Implementation #2518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
maitry63
wants to merge
4
commits into
keras-team:master
Choose a base branch
from
maitry63:modernbert-implementation
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| import keras | ||
| from keras import layers | ||
| from keras import ops | ||
| from keras_hub.src.api_export import keras_hub_export | ||
| from keras_hub.src.layers.modeling.reversible_embedding import ( | ||
| ReversibleEmbedding, | ||
| ) | ||
| from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding | ||
| from keras_hub.src.models.backbone import Backbone | ||
|
|
||
| from keras_hub.src.models.modernbert.modernbert_layers import ( | ||
| ModernBertEncoderLayer, | ||
| ) | ||
|
|
||
| from keras_hub.src.utils.keras_utils import gelu_approximate | ||
|
|
||
| @keras_hub_export("keras_hub.models.ModernBertBackbone") | ||
| class ModernBertBackbone(Backbone): | ||
| """ModernBERT backbone model. | ||
|
|
||
| ModernBERT features Rotary Positional Embeddings (RoPE), GeGLU activations, | ||
| RMSNorm, and Alternating Attention (interleaving local and global layers). | ||
|
|
||
| Args: | ||
| vocabulary_size: int. The size of the token vocabulary. | ||
| hidden_dim: int. The size of the transformer hidden state. | ||
| intermediate_dim: int. The output dimension of the GeGLU MLP. | ||
| num_layers: int. The number of transformer layers. | ||
| num_heads: int. The number of attention heads. | ||
| local_attention_window: int. Window size for local attention layers. | ||
| Defaults to `128`. | ||
| global_attn_every_n_layers: int. Frequency of global attention layers. | ||
| dropout: float. Dropout probability. | ||
| rotary_max_wavelength: int. Max wavelength for RoPE. | ||
| layer_norm_epsilon: float. Epsilon for RMSNorm. | ||
| dtype: string or `keras.DTypePolicy`. The dtype of the layers. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| vocabulary_size, | ||
| hidden_dim, | ||
| intermediate_dim, | ||
| num_layers, | ||
| num_heads, | ||
| local_attention_window=128, | ||
| global_attn_every_n_layers=3, | ||
| dropout=0.0, | ||
| rotary_max_wavelength=160000, | ||
| dtype=None, | ||
| layer_norm_epsilon=1e-5, | ||
| **kwargs, | ||
| ): | ||
| # === Inputs === | ||
| token_id_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="token_ids" | ||
| ) | ||
| padding_mask_input = keras.Input( | ||
| shape=(None,), dtype="int32", name="padding_mask" | ||
| ) | ||
|
|
||
| # === Layers === | ||
| self.token_embedding = ReversibleEmbedding( | ||
| input_dim=vocabulary_size, | ||
| output_dim=hidden_dim, | ||
| embeddings_initializer=keras.initializers.TruncatedNormal(stddev=0.02), | ||
| dtype=dtype, | ||
| name="token_embedding", | ||
| ) | ||
|
|
||
| self.position_embedding = RotaryEmbedding( | ||
| max_wavelength=rotary_max_wavelength, | ||
| dtype=dtype, | ||
| name="rotary_embedding", | ||
| ) | ||
| self.embeddings_layer_norm = keras.layers.LayerNormalization( | ||
| epsilon=layer_norm_epsilon, | ||
| dtype=dtype, | ||
| rms_scaling=True, | ||
| name="embeddings_layer_norm", | ||
| ) | ||
|
|
||
| # === Forward pass with dtype cast === | ||
| x = self.token_embedding(token_id_input) | ||
| x = self.embeddings_layer_norm(x) | ||
| x = keras.layers.Activation(None, dtype=dtype, name="embeddings_cast")(x) | ||
|
|
||
| for i in range(num_layers): | ||
| is_global = (i + 1) % global_attn_every_n_layers == 0 | ||
| current_window = None if is_global else local_attention_window | ||
|
|
||
| # Ensure dtype consistency inside encoder layer | ||
| x = ModernBertEncoderLayer( | ||
| hidden_dim=hidden_dim, | ||
| intermediate_dim=intermediate_dim, | ||
| num_heads=num_heads, | ||
| rotary_embedding=self.position_embedding, | ||
| local_attention_window=current_window, | ||
| dropout=dropout, | ||
| layer_norm_epsilon=layer_norm_epsilon, | ||
| dtype=dtype, | ||
| name=f"transformer_layer_{i}", | ||
| )(x, padding_mask=padding_mask_input) | ||
| # Force dtype consistency after each encoder layer | ||
| x = keras.layers.Activation(None, dtype=dtype, name=f"transformer_cast_{i}")(x) | ||
|
|
||
| sequence_output = layers.LayerNormalization( | ||
| epsilon=layer_norm_epsilon, | ||
| rms_scaling=True, | ||
| dtype=dtype, | ||
| name="final_norm", | ||
| )(x) | ||
| sequence_output = keras.layers.Activation(None, dtype=dtype, name="sequence_output_cast")(sequence_output) | ||
|
|
||
| super().__init__( | ||
| inputs={ | ||
| "token_ids": token_id_input, | ||
| "padding_mask": padding_mask_input, | ||
| }, | ||
| outputs=sequence_output, | ||
| dtype=dtype, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| # === Config Storage === | ||
| self.vocabulary_size = vocabulary_size | ||
| self.hidden_dim = hidden_dim | ||
| self.intermediate_dim = intermediate_dim | ||
| self.num_layers = num_layers | ||
| self.num_heads = num_heads | ||
| self.local_attention_window = local_attention_window | ||
| self.global_attn_every_n_layers = global_attn_every_n_layers | ||
| self.dropout = dropout | ||
| self.rotary_max_wavelength = rotary_max_wavelength | ||
| self.layer_norm_epsilon = layer_norm_epsilon | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update({ | ||
| "vocabulary_size": self.vocabulary_size, | ||
| "hidden_dim": self.hidden_dim, | ||
| "intermediate_dim": self.intermediate_dim, | ||
| "num_layers": self.num_layers, | ||
| "num_heads": self.num_heads, | ||
| "local_attention_window": self.local_attention_window, | ||
| "global_attn_every_n_layers": self.global_attn_every_n_layers, | ||
| "dropout": self.dropout, | ||
| "rotary_max_wavelength": self.rotary_max_wavelength, | ||
| "layer_norm_epsilon": self.layer_norm_epsilon, | ||
| }) | ||
| return config | ||
|
|
||
| @classmethod | ||
| def from_config(cls, config): | ||
| return cls(**config) |
53 changes: 53 additions & 0 deletions
53
keras_hub/src/models/modernbert/modernbert_backbone_test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| import keras | ||
| import pytest | ||
| from keras import ops | ||
|
|
||
| from keras_hub.src.models.modernbert.modernbert_backbone import ( | ||
| ModernBertBackbone | ||
| ) | ||
| from keras_hub.src.tests.test_case import TestCase | ||
|
|
||
| class ModernBertBackboneTest(TestCase): | ||
| """Tests for ModernBERT backbone.""" | ||
|
|
||
| def setUp(self): | ||
| """Set up a small configuration for testing.""" | ||
| self.init_kwargs = { | ||
| "vocabulary_size": 10, | ||
| "hidden_dim": 8, | ||
| "intermediate_dim": 64, | ||
| "num_layers": 2, | ||
| "num_heads": 4, | ||
| "local_attention_window": 128, | ||
| "global_attn_every_n_layers": 2, | ||
| "dropout": 0.0, | ||
| } | ||
| self.input_data = { | ||
| "token_ids": ops.ones((2, 5), dtype="int32"), | ||
| "padding_mask": ops.ones((2, 5), dtype="int32"), | ||
| } | ||
|
|
||
| def test_backbone_basics(self): | ||
| """Test backbone forward pass and standard KerasHub lifecycle. | ||
|
|
||
| This validates: | ||
| 1. Forward pass with the given input data. | ||
| 2. Config serialization (get_config/from_config). | ||
| 3. Model saving and loading via the `.keras` format. | ||
| 4. Backend-agnostic execution. | ||
| """ | ||
| self.run_backbone_test( | ||
| cls=ModernBertBackbone, | ||
| init_kwargs=self.init_kwargs, | ||
| input_data=self.input_data, | ||
| expected_output_shape=(2, 5, 8), | ||
| ) | ||
|
|
||
| @pytest.mark.large | ||
| def test_saved_model(self): | ||
| """Test that the model can be saved and loaded in the .keras format.""" | ||
| self.run_model_saving_test( | ||
| cls=ModernBertBackbone, | ||
| init_kwargs=self.init_kwargs, | ||
| input_data=self.input_data, | ||
| ) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
init_kwargsdictionary is missing the requiredlocal_attention_windowargument for instantiatingModernBertBackbone. This will cause aTypeErrorwhen running the tests. Please add this argument to the dictionary.