Adding ModernBert model #2477
Closed
Changes from all commits · 4 commits
Empty file.

keras_hub/src/models/modernbert/modernbert_backbone.py (137 additions)
```python
import keras
from keras import layers

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.modernbert.modernbert_layers import (
    ModernBertEncoderLayer,
)


@keras_hub_export("keras_hub.models.ModernBertBackbone")
class ModernBertBackbone(Backbone):
    """ModernBERT backbone model.

    ModernBERT features Rotary Positional Embeddings (RoPE), GeGLU
    activations, RMSNorm, and alternating attention. This backbone is
    designed to be used with `ModernBertTokenizer`.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        hidden_dim: int. The size of the transformer hidden state.
        intermediate_dim: int. The output dimension of the GeGLU MLP.
        num_layers: int. The number of transformer layers.
        num_heads: int. The number of attention heads.
        dropout: float. Dropout probability.
        local_attention_window: int. Window size for local layers
            (default 128).
        global_attn_every_n_layers: int. Frequency of global attention
            (default 3).
        rotary_max_wavelength: int. Max wavelength for RoPE
            (default 160000).
        layer_norm_epsilon: float. Epsilon for RMSNorm (default 1e-5).
        dtype: string or `keras.mixed_precision.DTypePolicy`. Data type.
    """

    def __init__(
        self,
        vocabulary_size,
        hidden_dim,
        intermediate_dim,
        num_layers,
        num_heads,
        local_attention_window=128,
        global_attn_every_n_layers=3,
        dropout=0.0,
        rotary_max_wavelength=160000,
        layer_norm_epsilon=1e-5,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=keras.initializers.TruncatedNormal(
                stddev=0.02
            ),
            name="token_embedding",
        )
        self.rotary_embedding = RotaryEmbedding(
            max_wavelength=rotary_max_wavelength,
            name="rotary_embedding",
        )
        # ModernBERT uses RMSNorm (no additive bias, rms_scaling=True).
        self.embeddings_layer_norm = layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            rms_scaling=True,
            name="embeddings_layer_norm",
        )
        self.transformer_layers = []
        # Alternating attention: every n-th layer is global, the rest
        # use a local sliding window.
        for i in range(num_layers):
            # A `None` window selects global attention for this layer.
            is_global = (i + 1) % global_attn_every_n_layers == 0
            current_window = None if is_global else local_attention_window
            self.transformer_layers.append(
                ModernBertEncoderLayer(
                    hidden_dim=hidden_dim,
                    intermediate_dim=intermediate_dim,
                    num_heads=num_heads,
                    rotary_embedding=self.rotary_embedding,
                    local_attention_window=current_window,
                    dropout=dropout,
                    layer_norm_epsilon=layer_norm_epsilon,
                    name=f"transformer_layer_{i}",
                )
            )
        self.final_norm = layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            rms_scaling=True,
            name="final_norm",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        x = self.embeddings_layer_norm(x)
        for layer in self.transformer_layers:
            x = layer(x, padding_mask=padding_mask_input)
        sequence_output = self.final_norm(x)
        # Instantiate via the functional API `Model` constructor.
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.local_attention_window = local_attention_window
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.dropout = dropout
        self.rotary_max_wavelength = rotary_max_wavelength
        self.layer_norm_epsilon = layer_norm_epsilon

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "local_attention_window": self.local_attention_window,
                "global_attn_every_n_layers": self.global_attn_every_n_layers,
                "dropout": self.dropout,
                "rotary_max_wavelength": self.rotary_max_wavelength,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
```
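A minimal usage sketch of the backbone above (the configuration values are illustrative toys, not a released preset):

```python
import numpy as np

from keras_hub.src.models.modernbert.modernbert_backbone import (
    ModernBertBackbone,
)

# With global_attn_every_n_layers=3, layers 3 and 6 (1-indexed) use
# global attention; the other four use the local sliding window.
backbone = ModernBertBackbone(
    vocabulary_size=100,
    hidden_dim=32,
    intermediate_dim=64,
    num_layers=6,
    num_heads=4,
    local_attention_window=8,
)
output = backbone(
    {
        "token_ids": np.ones((2, 10), dtype="int32"),
        "padding_mask": np.ones((2, 10), dtype="int32"),
    }
)
print(output.shape)  # (2, 10, 32)
```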
keras_hub/src/models/modernbert/modernbert_backbone_test.py (53 additions)
```python
import pytest
from keras import ops

from keras_hub.src.models.modernbert.modernbert_backbone import (
    ModernBertBackbone,
)
from keras_hub.src.tests.test_case import TestCase


class ModernBertBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "vocabulary_size": 10,
            "num_layers": 2,
            "num_heads": 4,
            "hidden_dim": 8,
            "intermediate_dim": 32,
        }
        self.input_data = {
            "token_ids": ops.ones((2, 5), dtype="int32"),
            "padding_mask": ops.ones((2, 5), dtype="int32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 8),
        )

    @pytest.mark.large
    def test_saved_model(self):
        """Verify the model can be saved and loaded accurately."""
        self.run_model_saving_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )

    def test_mixed_precision(self):
        """Verify the backbone works correctly with mixed precision."""
        self.run_backbone_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 8),
            run_mixed_precision_check=True,
        )
```
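Outside the harness, the round-trip that `test_saved_model` exercises looks roughly like this (a sketch assuming `keras_hub_export` registers the class for serialization; the file path is illustrative, and `run_model_saving_test` handles temporary paths and output comparison internally):

```python
import keras
import numpy as np

from keras_hub.src.models.modernbert.modernbert_backbone import (
    ModernBertBackbone,
)

backbone = ModernBertBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_heads=4,
    hidden_dim=8,
    intermediate_dim=32,
)
data = {
    "token_ids": np.ones((2, 5), dtype="int32"),
    "padding_mask": np.ones((2, 5), dtype="int32"),
}
before = backbone(data)
backbone.save("modernbert_backbone.keras")  # illustrative path
restored = keras.models.load_model("modernbert_backbone.keras")
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(before),
    keras.ops.convert_to_numpy(restored(data)),
    atol=1e-6,
)
```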
keras_hub/src/models/modernbert/modernbert_layers.py (170 additions)
```python
import keras
from keras import layers
from keras import ops

from keras_hub.src.utils.keras_utils import gelu_approximate


@keras.utils.register_keras_serializable(package="keras_hub")
class ModernBertMLP(layers.Layer):
    """ModernBERT MLP block using Gated Linear Units (GeGLU).

    Implements: `output = wo(activation(wi_0(x)) * wi_1(x))`.

    Args:
        hidden_dim: int. Input and output dimensionality.
        intermediate_dim: int. Inner gated projection dimensionality.
        activation: function. Activation function
            (default: `gelu_approximate`).
    """

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        activation=gelu_approximate,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.activation = keras.activations.get(activation)

        self.wi_0 = layers.Dense(intermediate_dim, use_bias=False, name="wi_0")
        self.wi_1 = layers.Dense(intermediate_dim, use_bias=False, name="wi_1")
        self.wo = layers.Dense(hidden_dim, use_bias=False, name="wo")

    def call(self, x):
        return self.wo(self.activation(self.wi_0(x)) * self.wi_1(x))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "activation": keras.activations.serialize(self.activation),
            }
        )
        return config
```
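A quick shape check of the gated MLP (a toy sketch; the values are arbitrary):

```python
import numpy as np

# GeGLU: two parallel projections to intermediate_dim, an elementwise
# gate between them, then a projection back down to hidden_dim.
mlp = ModernBertMLP(hidden_dim=8, intermediate_dim=32)
x = np.random.uniform(size=(2, 5, 8)).astype("float32")
y = mlp(x)
print(y.shape)  # (2, 5, 8)
```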
```python
@keras.utils.register_keras_serializable(package="keras_hub")
class ModernBertAttention(layers.Layer):
    """ModernBERT attention with RoPE and alternating-window support.

    Supports both global and local sliding-window attention.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        rotary_embedding=None,
        local_attention_window=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads
        self.rotary_embedding = rotary_embedding
        self.local_attention_window = local_attention_window

        self.qkv = layers.Dense(hidden_dim * 3, use_bias=False, name="Wqkv")
        self.out_dense = layers.Dense(hidden_dim, use_bias=False, name="Wo")

    def _get_sliding_window_mask(self, seq_len):
        # Each position may attend to neighbors within half the window
        # on either side of it.
        idx = ops.arange(seq_len)
        distance = ops.abs(idx[:, None] - idx[None, :])
        mask = distance <= (self.local_attention_window // 2)
        return ops.cast(mask, dtype="float32")

    def call(self, x, padding_mask=None):
        batch_size, seq_len = ops.shape(x)[0], ops.shape(x)[1]

        # Fused QKV projection, split into per-head q, k, v tensors.
        qkv = self.qkv(x)
        qkv = ops.reshape(
            qkv, (batch_size, seq_len, 3, self.num_heads, self.head_dim)
        )
        q, k, v = ops.unstack(qkv, axis=2)

        if self.rotary_embedding:
            q, k = self.rotary_embedding(q), self.rotary_embedding(k)

        q = ops.transpose(q, (0, 2, 1, 3))
        k = ops.transpose(k, (0, 2, 3, 1))
        v = ops.transpose(v, (0, 2, 1, 3))

        # Scaled dot-product attention scores.
        scores = ops.matmul(q, k) / ops.sqrt(ops.cast(self.head_dim, x.dtype))

        # ==== Sliding window mask ====
        if self.local_attention_window is not None:
            sw_mask = self._get_sliding_window_mask(seq_len)
            scores += (1.0 - sw_mask[None, None, :, :]) * -1e9

        if padding_mask is not None:
            p_mask = ops.cast(padding_mask, x.dtype)
            p_mask = p_mask[:, None, None, :]
            scores += (1.0 - p_mask) * -1e9

        attn = ops.softmax(scores, axis=-1)
        out = ops.matmul(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (batch_size, seq_len, self.hidden_dim))
        return self.out_dense(out)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "local_attention_window": self.local_attention_window,
            }
        )
        return config
```
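The sliding-window mask is easy to inspect on its own. With a length-6 sequence and `local_attention_window=4`, each position attends to itself and up to two neighbors on each side (a toy check):

```python
from keras import ops

attn = ModernBertAttention(hidden_dim=8, num_heads=2, local_attention_window=4)
print(ops.convert_to_numpy(attn._get_sliding_window_mask(6)))
# [[1. 1. 1. 0. 0. 0.]
#  [1. 1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 1. 0.]
#  [0. 1. 1. 1. 1. 1.]
#  [0. 0. 1. 1. 1. 1.]
#  [0. 0. 0. 1. 1. 1.]]
```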
```python
@keras.utils.register_keras_serializable(package="keras_hub")
class ModernBertEncoderLayer(layers.Layer):
    """ModernBERT encoder layer."""

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        num_heads,
        rotary_embedding=None,
        local_attention_window=None,
        dropout=0.0,
        layer_norm_epsilon=1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.local_attention_window = local_attention_window
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon

        self.attn_norm = layers.LayerNormalization(
            epsilon=layer_norm_epsilon, rms_scaling=True, name="attn_norm"
        )
        self.attn = ModernBertAttention(
            hidden_dim,
            num_heads,
            rotary_embedding,
            local_attention_window,
            name="attn",
        )
        self.mlp_norm = layers.LayerNormalization(
            epsilon=layer_norm_epsilon, rms_scaling=True, name="mlp_norm"
        )
        self.mlp = ModernBertMLP(hidden_dim, intermediate_dim, name="mlp")
        self.dropout_layer = layers.Dropout(dropout)

    def call(self, x, padding_mask=None):
        # Pre-norm residual attention block.
        res = x
        x = self.attn_norm(x)
        x = self.attn(x, padding_mask=padding_mask)
        x = res + self.dropout_layer(x)

        # Pre-norm residual MLP block.
        res = x
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = res + self.dropout_layer(x)
        return x
```
Contributor comment on lines 148 to 158: the `call` method can be written more compactly as two residual paths:

```python
def call(self, x, padding_mask=None):
    # Attention residual path.
    x = x + self.dropout_layer(
        self.attn(self.attn_norm(x), padding_mask=padding_mask)
    )
    # MLP residual path.
    x = x + self.dropout_layer(self.mlp(self.mlp_norm(x)))
    return x
```
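Both forms implement the same pre-norm residual pattern (normalize, transform, dropout, add), so the suggested rewrite is behavior-preserving; it only tightens the code.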
```python
    # ModernBertEncoderLayer, continued.
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "local_attention_window": self.local_attention_window,
                "dropout": self.dropout,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
```
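A standalone sketch of one encoder layer with a local window (toy sizes; the rotary embedding is the same KerasHub layer the backbone wires in):

```python
import numpy as np

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding

layer = ModernBertEncoderLayer(
    hidden_dim=16,
    intermediate_dim=32,
    num_heads=2,
    rotary_embedding=RotaryEmbedding(max_wavelength=160000),
    local_attention_window=4,
)
x = np.random.uniform(size=(1, 6, 16)).astype("float32")
y = layer(x, padding_mask=np.ones((1, 6), dtype="int32"))
print(y.shape)  # (1, 6, 16)
```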
Review comment: These attributes are already assigned at the beginning of the `__init__` method (lines 63-70). This block of code is redundant and can be removed to improve maintainability.