diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 0a71dbcace..4a413ff165 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -369,6 +369,12 @@ from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import (
     MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor,
 )
+from keras_hub.src.models.modernbert.modernbert_backbone import (
+    ModernBertBackbone as ModernBertBackbone,
+)
+from keras_hub.src.models.modernbert.modernbert_tokenizer import (
+    ModernBertTokenizer as ModernBertTokenizer,
+)
 from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
     MoonshineAudioToText as MoonshineAudioToText,
 )
diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py
index 082078184f..428aa62d7f 100644
--- a/keras_hub/api/tokenizers/__init__.py
+++ b/keras_hub/api/tokenizers/__init__.py
@@ -58,6 +58,9 @@ from keras_hub.src.models.mixtral.mixtral_tokenizer import (
     MixtralTokenizer as MixtralTokenizer,
 )
+from keras_hub.src.models.modernbert.modernbert_tokenizer import (
+    ModernBertTokenizer as ModernBertTokenizer,
+)
 from keras_hub.src.models.moonshine.moonshine_tokenizer import (
     MoonshineTokenizer as MoonshineTokenizer,
 )
diff --git a/keras_hub/src/models/modernbert/__init__.py b/keras_hub/src/models/modernbert/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/keras_hub/src/models/modernbert/modernbert_backbone.py b/keras_hub/src/models/modernbert/modernbert_backbone.py
new file mode 100644
index 0000000000..4ca9eb92b6
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_backbone.py
@@ -0,0 +1,155 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.modernbert.modernbert_layers import (
+    ModernBertEncoderLayer,
+)
+
+
+@keras_hub_export("keras_hub.models.ModernBertBackbone")
+class ModernBertBackbone(Backbone):
+    """ModernBERT backbone model.
+
+    ModernBERT features Rotary Positional Embeddings (RoPE), GeGLU
+    activations, RMSNorm, and alternating attention (interleaved local and
+    global layers).
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        hidden_dim: int. The size of the transformer hidden state.
+        intermediate_dim: int. The output dimension of the GeGLU MLP.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads.
+        local_attention_window: int. Window size for local attention layers.
+            Defaults to `128`.
+        global_attn_every_n_layers: int. Every n-th layer uses global
+            attention; all other layers use local sliding-window attention.
+            Defaults to `3`.
+        dropout: float. Dropout probability.
+        rotary_max_wavelength: int. Max wavelength for RoPE.
+        layer_norm_epsilon: float. Epsilon for RMSNorm.
+        dtype: string or `keras.DTypePolicy`. The dtype of the layers.
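+
+    Example:
+
+    A minimal usage sketch with a randomly initialized backbone (the config
+    below is illustrative, not a published preset):
+    ```python
+    import numpy as np
+    import keras_hub
+
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1] * 12]),
+    }
+
+    # Randomly initialized ModernBERT backbone with a custom config.
+    model = keras_hub.models.ModernBertBackbone(
+        vocabulary_size=50368,
+        hidden_dim=128,
+        intermediate_dim=256,
+        num_layers=2,
+        num_heads=4,
+    )
+    model(input_data)
+    ```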
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        hidden_dim,
+        intermediate_dim,
+        num_layers,
+        num_heads,
+        local_attention_window=128,
+        global_attn_every_n_layers=3,
+        dropout=0.0,
+        rotary_max_wavelength=160000,
+        dtype=None,
+        layer_norm_epsilon=1e-5,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            embeddings_initializer=keras.initializers.TruncatedNormal(
+                stddev=0.02
+            ),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.position_embedding = RotaryEmbedding(
+            max_wavelength=rotary_max_wavelength,
+            dtype=dtype,
+            name="rotary_embedding",
+        )
+        self.embeddings_layer_norm = keras.layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            rms_scaling=True,
+            dtype=dtype,
+            name="embeddings_layer_norm",
+        )
+        # Keep a list of encoder layers so downstream tooling (e.g. the
+        # checkpoint conversion script) can address them directly.
+        self.transformer_layers = []
+        for i in range(num_layers):
+            # Every `global_attn_every_n_layers`-th layer attends globally;
+            # all other layers use a local sliding window.
+            is_global = (i + 1) % global_attn_every_n_layers == 0
+            layer = ModernBertEncoderLayer(
+                hidden_dim=hidden_dim,
+                intermediate_dim=intermediate_dim,
+                num_heads=num_heads,
+                rotary_embedding=self.position_embedding,
+                local_attention_window=(
+                    None if is_global else local_attention_window
+                ),
+                dropout=dropout,
+                layer_norm_epsilon=layer_norm_epsilon,
+                dtype=dtype,
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+        self.final_norm = keras.layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            rms_scaling=True,
+            dtype=dtype,
+            name="final_norm",
+        )
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+        x = self.token_embedding(token_id_input)
+        x = self.embeddings_layer_norm(x)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(x, padding_mask=padding_mask_input)
+        sequence_output = self.final_norm(x)
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "padding_mask": padding_mask_input,
+            },
+            outputs=sequence_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.local_attention_window = local_attention_window
+        self.global_attn_every_n_layers = global_attn_every_n_layers
+        self.dropout = dropout
+        self.rotary_max_wavelength = rotary_max_wavelength
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "local_attention_window": self.local_attention_window,
+                "global_attn_every_n_layers": self.global_attn_every_n_layers,
+                "dropout": self.dropout,
+                "rotary_max_wavelength": self.rotary_max_wavelength,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
diff --git a/keras_hub/src/models/modernbert/modernbert_backbone_test.py b/keras_hub/src/models/modernbert/modernbert_backbone_test.py
new file mode 100644
index 0000000000..16b9221608
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_backbone_test.py
@@ -0,0 +1,53 @@
+import pytest
+from keras import ops
+
+from keras_hub.src.models.modernbert.modernbert_backbone import (
+    ModernBertBackbone,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class ModernBertBackboneTest(TestCase):
+    """Tests for the ModernBERT backbone."""
+
+    def setUp(self):
+        """Set up a small configuration for testing."""
+        self.init_kwargs = {
+            "vocabulary_size": 10,
+            "hidden_dim": 8,
+            "intermediate_dim": 64,
+            "num_layers": 2,
+            "num_heads": 4,
+            "local_attention_window": 128,
+            "global_attn_every_n_layers": 2,
+            "dropout": 0.0,
+        }
+        self.input_data = {
+            "token_ids": ops.ones((2, 5), dtype="int32"),
+            "padding_mask": ops.ones((2, 5), dtype="int32"),
+        }
+
+    def test_backbone_basics(self):
+        """Test backbone forward pass and standard KerasHub lifecycle.
+
+        This validates:
+        1. Forward pass with the given input data.
+        2. Config serialization (`get_config`/`from_config`).
+        3. Model saving and loading via the `.keras` format.
+        4. Backend-agnostic execution.
+        """
+        self.run_backbone_test(
+            cls=ModernBertBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output_shape=(2, 5, 8),
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        """Test that the model can be saved and loaded in `.keras` format."""
+        self.run_model_saving_test(
+            cls=ModernBertBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
diff --git a/keras_hub/src/models/modernbert/modernbert_layers.py b/keras_hub/src/models/modernbert/modernbert_layers.py
new file mode 100644
index 0000000000..8fad193ce9
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_layers.py
@@ -0,0 +1,194 @@
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.utils.keras_utils import gelu_approximate
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class ModernBertMLP(layers.Layer):
+    """ModernBERT MLP block using Gated Linear Units (GeGLU).
+
+    Implements `output = wo(activation(wi_0(x)) * wi_1(x))`.
+
+    Args:
+        hidden_dim: int. Input and output dimensionality.
+        intermediate_dim: int. Inner gated projection dimensionality.
+        activation: function. Activation applied to the gate branch.
+            Defaults to `gelu_approximate`.
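+
+    Example:
+
+    A shape-level sketch (the values are illustrative):
+    ```python
+    from keras import ops
+
+    mlp = ModernBertMLP(hidden_dim=8, intermediate_dim=16)
+    y = mlp(ops.ones((2, 4, 8)))  # Output shape: (2, 4, 8).
+    ```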
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        intermediate_dim,
+        activation=gelu_approximate,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.activation = keras.activations.get(activation)
+
+        # Sub-layers follow this layer's dtype policy rather than a
+        # hard-coded float16, so mixed precision remains configurable.
+        self.wi_0 = layers.Dense(
+            intermediate_dim,
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="wi_0",
+        )
+        self.wi_1 = layers.Dense(
+            intermediate_dim,
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="wi_1",
+        )
+        self.wo = layers.Dense(
+            hidden_dim,
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="wo",
+        )
+
+    def call(self, x):
+        # Gated Linear Unit: activate the `wi_0` branch, gate it with the
+        # `wi_1` branch, then project back to `hidden_dim`.
+        return self.wo(self.activation(self.wi_0(x)) * self.wi_1(x))
+
+    def compute_output_spec(self, x_spec, **kwargs):
+        # hidden_dim -> intermediate_dim -> hidden_dim; shape is preserved.
+        return keras.KerasTensor(shape=x_spec.shape, dtype=self.compute_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "activation": keras.activations.serialize(self.activation),
+            }
+        )
+        return config
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class ModernBertAttention(layers.Layer):
+    """ModernBERT attention with RoPE and alternating-window support.
+
+    Supports both global attention and local sliding-window attention.
+
+    Args:
+        hidden_dim: int. The size of the transformer hidden state.
+        num_heads: int. The number of attention heads.
+        rotary_embedding: A `RotaryEmbedding` layer applied to queries and
+            keys, or `None` to skip RoPE.
+        local_attention_window: int or `None`. If set, attention is
+            restricted to a sliding window of this total size.
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        num_heads,
+        rotary_embedding=None,
+        local_attention_window=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.head_dim = hidden_dim // num_heads
+        self.rotary_embedding = rotary_embedding
+        self.local_attention_window = local_attention_window
+
+        self.qkv = layers.Dense(
+            hidden_dim * 3,
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="Wqkv",
+        )
+        self.out_dense = layers.Dense(
+            hidden_dim,
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="Wo",
+        )
+
+    def _get_sliding_window_mask(self, seq_len, dtype):
+        # Allow attention only between positions at most half the window
+        # size apart.
+        idx = ops.arange(seq_len)
+        distance = ops.abs(idx[:, None] - idx[None, :])
+        mask = distance <= (self.local_attention_window // 2)
+        return ops.cast(mask, dtype=dtype)
+
+    def call(self, x, padding_mask=None):
+        batch_size, seq_len = ops.shape(x)[0], ops.shape(x)[1]
+        compute_dtype = self.compute_dtype
+
+        qkv = self.qkv(x)
+        qkv = ops.reshape(
+            qkv, (batch_size, seq_len, 3, self.num_heads, self.head_dim)
+        )
+        q, k, v = ops.unstack(qkv, axis=2)
+
+        if self.rotary_embedding:
+            q = self.rotary_embedding(q)
+            k = self.rotary_embedding(k)
+
+        q = ops.transpose(q, (0, 2, 1, 3))
+        k = ops.transpose(k, (0, 2, 3, 1))
+        v = ops.transpose(v, (0, 2, 1, 3))
+
+        # Scaled dot-product attention.
+        scale = ops.sqrt(ops.cast(self.head_dim, compute_dtype))
+        scores = ops.matmul(q, k) / scale
+
+        # A large negative bias that is still representable in float16.
+        mask_value = ops.cast(-1e4, compute_dtype)
+        one = ops.cast(1.0, compute_dtype)
+
+        if self.local_attention_window is not None:
+            sw_mask = self._get_sliding_window_mask(seq_len, compute_dtype)
+            scores += (one - sw_mask[None, None, :, :]) * mask_value
+
+        if padding_mask is not None:
+            p_mask = ops.cast(padding_mask, compute_dtype)
+            p_mask = p_mask[:, None, None, :]
+            scores += (one - p_mask) * mask_value
+
+        attn = ops.softmax(scores, axis=-1)
+
+        out = ops.matmul(attn, v)
+        out = ops.transpose(out, (0, 2, 1, 3))
+        out = ops.reshape(out, (batch_size, seq_len, self.hidden_dim))
+        return self.out_dense(out)
+
+    def compute_output_spec(self, x_spec, padding_mask_spec=None, **kwargs):
+        # Attention preserves the input shape.
+        return keras.KerasTensor(shape=x_spec.shape, dtype=self.compute_dtype)
+
+    def get_config(self):
+        # `rotary_embedding` is injected by the backbone and intentionally
+        # not serialized here.
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "local_attention_window": self.local_attention_window,
+            }
+        )
+        return config
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class ModernBertEncoderLayer(layers.Layer):
+    """A single pre-norm ModernBERT encoder layer (attention + GeGLU MLP)."""
+
+    def __init__(
+        self,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        rotary_embedding=None,
+        local_attention_window=None,
+        dropout=0.0,
+        layer_norm_epsilon=1e-5,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_heads = num_heads
+        self.local_attention_window = local_attention_window
+        self.dropout = dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+        self.attn_norm = layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            rms_scaling=True,
+            dtype=self.dtype_policy,
+            name="attn_norm",
+        )
+        self.attn = ModernBertAttention(
+            hidden_dim,
+            num_heads,
+            rotary_embedding,
+            local_attention_window,
+            dtype=self.dtype_policy,
+            name="attn",
+        )
+        self.mlp_norm = layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            rms_scaling=True,
+            dtype=self.dtype_policy,
+            name="mlp_norm",
+        )
+        self.mlp = ModernBertMLP(
+            hidden_dim,
+            intermediate_dim,
+            dtype=self.dtype_policy,
+            name="mlp",
+        )
+        self.dropout_layer = layers.Dropout(dropout, dtype=self.dtype_policy)
+
+    def call(self, x, padding_mask=None):
+        # Pre-norm residual blocks: the skip connection carries the
+        # un-normalized input, not the normalized activations.
+        residual = x
+        x = self.attn_norm(x)
+        x = self.attn(x, padding_mask=padding_mask)
+        x = residual + self.dropout_layer(x)
+
+        residual = x
+        x = self.mlp_norm(x)
+        x = self.mlp(x)
+        x = residual + self.dropout_layer(x)
+        return x
+
+    def compute_output_spec(self, x_spec, **kwargs):
+        # Both residual branches preserve the input shape.
+        return keras.KerasTensor(shape=x_spec.shape, dtype=self.compute_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "local_attention_window": self.local_attention_window,
+                "dropout": self.dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
diff --git a/keras_hub/src/models/modernbert/modernbert_layers_test.py b/keras_hub/src/models/modernbert/modernbert_layers_test.py
new file mode 100644
index 0000000000..ea5d7d15fe
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_layers_test.py
@@ -0,0 +1,61 @@
+import numpy as np
+from keras import ops
+
+from keras_hub.src.models.modernbert.modernbert_layers import (
+    ModernBertAttention,
+    ModernBertEncoderLayer,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class ModernBertLayersTest(TestCase):
+    def test_attention_masking_logic(self):
+        """Verify that the attention layer correctly handles padding masks."""
+        layer = ModernBertAttention(hidden_dim=16, num_heads=2)
+        x = ops.ones((1, 4, 16))
+        mask = ops.convert_to_tensor([[1, 1, 0, 0]], dtype="int32")
+        output = layer(x, padding_mask=mask)
+
+        self.assertFalse(np.any(np.isnan(ops.convert_to_numpy(output))))
+        self.assertEqual(output.shape, (1, 4, 16))
+
+    def test_serialization_attributes(self):
+        """Explicitly verify that custom attributes are restored."""
+        layer = ModernBertEncoderLayer(
+            hidden_dim=16,
+            intermediate_dim=32,
+            num_heads=2,
+            local_attention_window=64,
+        )
+        config = layer.get_config()
+        new_layer = ModernBertEncoderLayer.from_config(config)
+        self.assertEqual(new_layer.local_attention_window, 64)
+        self.assertEqual(new_layer.hidden_dim, 16)
+
+    def test_sliding_window_mask_creation(self):
+        """Directly check the internal mask generation logic."""
+        layer = ModernBertAttention(
+            hidden_dim=8,
+            num_heads=2,
+            local_attention_window=2,
+        )
+        mask = layer._get_sliding_window_mask(seq_len=4, dtype="float32")
+        expected = [
+            [1, 1, 0, 0],
+            [1, 1, 1, 0],
+            [0, 1, 1, 1],
+            [0, 0, 1, 1],
+        ]
+        self.assertAllClose(mask, expected)
diff --git a/keras_hub/src/models/modernbert/modernbert_masked_lm.py b/keras_hub/src/models/modernbert/modernbert_masked_lm.py
new file mode 100644
index 0000000000..d7c4154432
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_masked_lm.py
@@ -0,0 +1,106 @@
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.masked_lm import MaskedLM
+from keras_hub.src.models.modernbert.modernbert_backbone import (
+    ModernBertBackbone,
+)
+from keras_hub.src.models.modernbert.modernbert_preprocessor import (
+    ModernBertMaskedLMPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.ModernBertMaskedLM")
+class ModernBertMaskedLM(MaskedLM):
+    """ModernBERT masked LM task model.
+
+    The masked LM model provides a prediction head for the masked language
+    modeling task. It is composed of a `keras_hub.models.ModernBertBackbone`
+    and a prediction head which projects the backbone's hidden states back
+    to the vocabulary space.
+
+    This model can be used for pre-training or for fine-tuning on a
+    specific corpus.
+
+    Args:
+        backbone: A `keras_hub.models.ModernBertBackbone` instance.
+        preprocessor: A `keras_hub.models.ModernBertMaskedLMPreprocessor` or
+            `None`. If `None`, this model will not handle input
+            preprocessing.
+        **kwargs: Standard `keras.Model` arguments.
+
+    Example:
+    ```python
+    import keras_hub
+
+    # Create a tokenizer, preprocessor, and backbone from scratch. The
+    # vocabulary and merges files here are placeholders.
+    tokenizer = keras_hub.models.ModernBertTokenizer(
+        vocabulary="vocab.json",
+        merges="merges.txt",
+    )
+    preprocessor = keras_hub.models.ModernBertMaskedLMPreprocessor(
+        tokenizer=tokenizer,
+        sequence_length=128,
+    )
+    backbone = keras_hub.models.ModernBertBackbone(
+        vocabulary_size=50368,
+        hidden_dim=768,
+        intermediate_dim=1152,
+        num_layers=22,
+        num_heads=12,
+    )
+
+    # Instantiate the masked LM task model.
+    masked_lm = keras_hub.models.ModernBertMaskedLM(
+        backbone=backbone,
+        preprocessor=preprocessor,
+    )
+
+    # Predict on raw text strings.
+    raw_data = ["The quick brown fox [MASK] over the lazy dog."]
+    predictions = masked_lm.predict(raw_data)
+    ```
+    """
+
+    backbone_cls = ModernBertBackbone
+    preprocessor_cls = ModernBertMaskedLMPreprocessor
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        # === Layers ===
+        # Attach the backbone and preprocessor as attributes; `keras.Model`
+        # does not accept them as constructor arguments.
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        mask_positions = keras.Input(
+            shape=(None,), dtype="int32", name="mask_positions"
+        )
+        # Shape: (batch_size, sequence_length, hidden_dim).
+        sequence_output = backbone(inputs)
+        # Gather the hidden states of the masked-out positions.
+        x = ops.take_along_axis(
+            sequence_output,
+            mask_positions[:, :, None],
+            axis=1,
+        )
+        # ModernBERT uses RMSNorm (`LayerNormalization` with
+        # `rms_scaling=True`) in the MLM head.
+        x = layers.LayerNormalization(
+            epsilon=backbone.layer_norm_epsilon,
+            rms_scaling=True,
+            name="mlm_head_norm",
+        )(x)
+        # Shape: (batch_size, mask_selection_length, vocabulary_size).
+        logits = layers.Dense(
+            backbone.vocabulary_size,
+            kernel_initializer=keras.initializers.TruncatedNormal(
+                stddev=0.02
+            ),
+            name="mlm_head_logits",
+        )(x)
+        super().__init__(
+            inputs={**inputs, "mask_positions": mask_positions},
+            outputs=logits,
+            **kwargs,
+        )
diff --git a/keras_hub/src/models/modernbert/modernbert_masked_lm_test.py b/keras_hub/src/models/modernbert/modernbert_masked_lm_test.py
new file mode 100644
index 0000000000..52e9c246a6
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_masked_lm_test.py
@@ -0,0 +1,68 @@
+import keras
+import pytest
+
+from keras_hub.src.models.modernbert.modernbert_backbone import (
+    ModernBertBackbone,
+)
+from keras_hub.src.models.modernbert.modernbert_masked_lm import (
+    ModernBertMaskedLM,
+)
+from keras_hub.src.models.modernbert.modernbert_preprocessor import (
+    ModernBertMaskedLMPreprocessor,
+)
+from keras_hub.src.models.modernbert.modernbert_tokenizer import (
+    ModernBertTokenizer,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class ModernBertMaskedLMTest(TestCase):
+    def setUp(self):
+        # The tokenizer requires the `[CLS]`, `[SEP]`, `[PAD]`, `[UNK]` and
+        # `[MASK]` special tokens to be present in the vocabulary.
+        self.vocab = ["[CLS]", "[PAD]", "[SEP]", "[UNK]", "[MASK]"]
+        self.vocab += ["the", "Ġquick", "Ġbrown", "Ġfox"]
+        self.vocabulary = {token: i for i, token in enumerate(self.vocab)}
+        # Merge rules that fully compose each word in the vocabulary.
+        self.merges = ["t h", "th e", "Ġ q", "Ġ b", "Ġ f"]
+        self.merges += ["Ġq u", "Ġqu i", "Ġqui c", "Ġquic k"]
+        self.merges += ["Ġb r", "Ġbr o", "Ġbro w", "Ġbrow n"]
+        self.merges += ["Ġf o", "Ġfo x"]
+        self.tokenizer = ModernBertTokenizer(
+            vocabulary=self.vocabulary,
+            merges=self.merges,
+        )
+        self.preprocessor = ModernBertMaskedLMPreprocessor(
+            tokenizer=self.tokenizer,
+            sequence_length=10,
+            mask_selection_rate=0.2,
+            mask_selection_length=2,
+        )
+        self.backbone = ModernBertBackbone(
+            vocabulary_size=len(self.vocab),
+            hidden_dim=16,
+            intermediate_dim=32,
+            local_attention_window=128,
+            num_layers=2,
+            num_heads=2,
+        )
+        self.init_kwargs = {
+            "backbone": self.backbone,
+            "preprocessor": self.preprocessor,
+        }
+        self.model = ModernBertMaskedLM(
+            backbone=self.backbone,
+            preprocessor=self.preprocessor,
+        )
+        # Preprocessed features, for tests that call the model directly.
+        self.input_data = self.preprocessor(["the quick brown fox"])[0]
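+
+    def test_masked_lm_basics(self):
+        # A minimal end-to-end check. This assumes the shared
+        # `run_task_test` helper on `TestCase`, as used by other KerasHub
+        # task model tests.
+        self.run_task_test(
+            cls=ModernBertMaskedLM,
+            init_kwargs=self.init_kwargs,
+            train_data=(["the quick brown fox", "the quick brown fox"],),
+            expected_output_shape=(2, 2, len(self.vocab)),
+        )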
+
+    @pytest.mark.extra_large
+    def test_fit(self):
+        # Verify the model can run a single training step on raw strings;
+        # the attached preprocessor handles tokenization and masking.
+        self.model.compile(
+            optimizer="adam",
+            # The head outputs logits, so the loss must not expect
+            # probabilities.
+            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        )
+        self.model.fit(
+            x=["the quick brown fox", "the quick brown fox"],
+            epochs=1,
+            batch_size=2,
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=ModernBertMaskedLM,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
diff --git a/keras_hub/src/models/modernbert/modernbert_preprocessor.py b/keras_hub/src/models/modernbert/modernbert_preprocessor.py
new file mode 100644
index 0000000000..daf8d02a5c
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_preprocessor.py
@@ -0,0 +1,105 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import (
+    MaskedLMMaskGenerator,
+)
+from keras_hub.src.layers.preprocessing.multi_segment_packer import (
+    MultiSegmentPacker,
+)
+from keras_hub.src.models.preprocessor import Preprocessor
+
+
+@keras_hub_export("keras_hub.models.ModernBertMaskedLMPreprocessor")
+class ModernBertMaskedLMPreprocessor(Preprocessor):
+    """ModernBERT masked LM preprocessor.
+
+    This class prepares raw strings for masked language modeling (MLM) with
+    the ModernBERT architecture. It tokenizes the input, packs it with
+    special tokens, and generates masks for training.
+
+    The output of this preprocessor is a tuple `(x, y, sample_weight)`,
+    where `x` is a dictionary containing:
+
+    - `"token_ids"`: The masked token ids.
+    - `"padding_mask"`: A mask of the non-padding tokens.
+    - `"mask_positions"`: The indices of the tokens that were masked.
+
+    `y` contains the original token ids at the masked positions, and
+    `sample_weight` contains the per-position weights for the loss.
+
+    Args:
+        tokenizer: A `keras_hub.models.ModernBertTokenizer` instance.
+        sequence_length: int. The length of the packed sequence.
+        mask_selection_rate: float. The probability of masking a token.
+        mask_selection_length: int. The maximum number of tokens to mask per
+            sequence.
+        mask_token_rate: float. The probability that a selected token is
+            replaced with the `"[MASK]"` token. Defaults to `0.8`.
+        random_token_rate: float. The probability that a selected token is
+            replaced with a random token. Defaults to `0.1`.
+        **kwargs: Standard `keras.layers.Layer` arguments.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_hub.models.ModernBertMaskedLMPreprocessor.from_preset(
+        "modernbert_base"
+    )
+
+    # Preprocess raw text.
+    x, y, sw = preprocessor(["The quick brown fox jumps over the dog."])
+    ```
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=512,
+        mask_selection_rate=0.15,
+        mask_selection_length=96,
+        mask_token_rate=0.8,
+        random_token_rate=0.1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tokenizer = tokenizer
+
+        # Pack with `[CLS]` at the start and `[SEP]` at the end of each
+        # sequence.
+        self.packer = MultiSegmentPacker(
+            start_value=tokenizer.start_token_id,
+            end_value=tokenizer.end_token_id,
+            pad_value=tokenizer.pad_token_id,
+            sequence_length=sequence_length,
+        )
+        self.masker = MaskedLMMaskGenerator(
+            mask_selection_rate=mask_selection_rate,
+            mask_selection_length=mask_selection_length,
+            mask_token_rate=mask_token_rate,
+            random_token_rate=random_token_rate,
+            vocabulary_size=tokenizer.vocabulary_size(),
+            mask_token_id=tokenizer.mask_token_id,
+            unselectable_token_ids=[
+                tokenizer.start_token_id,
+                tokenizer.end_token_id,
+                tokenizer.pad_token_id,
+            ],
+        )
+        self.sequence_length = sequence_length
+
+    def call(self, x, y=None, sample_weight=None):
+        """Transform raw strings into masked token sequences."""
+        x = self.tokenizer(x)
+        # `MultiSegmentPacker` returns `(token_ids, segment_ids)`; the
+        # padding mask is derived from the pad token id.
+        token_ids, segment_ids = self.packer(x)
+        padding_mask = token_ids != self.tokenizer.pad_token_id
+
+        mask_data = self.masker(token_ids)
+
+        return (
+            {
+                "token_ids": mask_data["token_ids"],
+                "padding_mask": ops.cast(padding_mask, dtype="int32"),
+                "mask_positions": mask_data["mask_positions"],
+            },
+            mask_data["mask_ids"],
+            mask_data["mask_weights"],
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sequence_length": self.sequence_length,
+                "mask_selection_rate": self.masker.mask_selection_rate,
+                "mask_selection_length": self.masker.mask_selection_length,
+                "mask_token_rate": self.masker.mask_token_rate,
+                "random_token_rate": self.masker.random_token_rate,
+            }
+        )
+        return config
diff --git a/keras_hub/src/models/modernbert/modernbert_preprocessor_test.py b/keras_hub/src/models/modernbert/modernbert_preprocessor_test.py
new file mode 100644
index 0000000000..4cbeddf669
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_preprocessor_test.py
@@ -0,0 +1,63 @@
+from keras_hub.src.models.modernbert.modernbert_preprocessor import (
+    ModernBertMaskedLMPreprocessor,
+)
+from keras_hub.src.models.modernbert.modernbert_tokenizer import (
+    ModernBertTokenizer,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class ModernBertMaskedLMPreprocessorTest(TestCase):
+    def setUp(self):
+        # Token ids are chosen to match the expected outputs below:
+        # "[PAD]"=0, "[CLS]"=1, "[SEP]"=2, "[MASK]"=4, "the"=5,
+        # "Ġbrown"=6, "Ġfox"=8, "Ġquick"=10.
+        self.vocab = ["[PAD]", "[CLS]", "[SEP]", "[UNK]", "[MASK]", "the"]
+        self.vocab += ["Ġbrown", "Ġover", "Ġfox", "Ġdog", "Ġquick"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["t h", "th e", "Ġ q", "Ġ b", "Ġ f"]
+        self.merges += ["Ġq u", "Ġqu i", "Ġqui c", "Ġquic k"]
+        self.merges += ["Ġb r", "Ġbr o", "Ġbro w", "Ġbrow n"]
+        self.merges += ["Ġf o", "Ġfo x"]
+        self.tokenizer = ModernBertTokenizer(
+            vocabulary=self.vocab,
+            merges=self.merges,
+        )
+        self.init_kwargs = {
+            "tokenizer": self.tokenizer,
+            # Simplify testing by masking every selectable token, always
+            # with the mask token.
+            "mask_selection_rate": 1.0,
+            "mask_selection_length": 4,
+            "mask_token_rate": 1.0,
+            "random_token_rate": 0.0,
+            "sequence_length": 12,
+        }
+        self.input_data = ["the quick brown fox"]
+
+    def test_preprocessor_basics(self):
+        self.run_preprocessor_test(
+            cls=ModernBertMaskedLMPreprocessor,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=(
+                {
+                    "token_ids": [[1, 4, 4, 4, 4, 2, 0, 0, 0, 0, 0, 0]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]],
+                    "mask_positions": [[1, 2, 3, 4]],
+                },
+                [[5, 10, 6, 8]],
+                [[1.0, 1.0, 1.0, 1.0]],
+            ),
+        )
+
+    def test_no_masking_zero_rate(self):
+        no_mask_preprocessor = ModernBertMaskedLMPreprocessor(
+            self.tokenizer,
+            mask_selection_rate=0.0,
+            mask_selection_length=4,
+            sequence_length=12,
+        )
+        input_data = ["the quick brown fox"]
+        self.assertAllClose(
+            no_mask_preprocessor(input_data),
+            (
+                {
+                    "token_ids": [[1, 5, 10, 6, 8, 2, 0, 0, 0, 0, 0, 0]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]],
"mask_positions": [[0, 0, 0, 0]], + }, + [[0, 0, 0, 0]], + [[0.0, 0.0, 0.0, 0.0]], + ), + ) diff --git a/keras_hub/src/models/modernbert/modernbert_tokenizer.py b/keras_hub/src/models/modernbert/modernbert_tokenizer.py new file mode 100644 index 0000000000..e622d199d5 --- /dev/null +++ b/keras_hub/src/models/modernbert/modernbert_tokenizer.py @@ -0,0 +1,35 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.modernbert.modernbert_backbone import ( + ModernBertBackbone, +) +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + +@keras_hub_export( + [ + "keras_hub.tokenizers.ModernBertTokenizer", + "keras_hub.models.ModernBertTokenizer", + ] +) +class ModernBertTokenizer(BytePairTokenizer): + backbone_cls = ModernBertBackbone + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + self._add_special_token("[CLS]", "cls_token") + self._add_special_token("[SEP]", "sep_token") + self._add_special_token("[PAD]", "pad_token") + self._add_special_token("[UNK]", "unk_token") + self._add_special_token("[MASK]", "mask_token") + # Also add `tokenizer.start_token` and `tokenizer.end_token` for + # compatibility with other tokenizers. + self._add_special_token("[CLS]", "start_token") + self._add_special_token("[SEP]", "end_token") + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) \ No newline at end of file diff --git a/keras_hub/src/models/modernbert/modernbert_tokenizer_test.py b/keras_hub/src/models/modernbert/modernbert_tokenizer_test.py new file mode 100644 index 0000000000..b863d03724 --- /dev/null +++ b/keras_hub/src/models/modernbert/modernbert_tokenizer_test.py @@ -0,0 +1,35 @@ +from keras_hub.src.models.modernbert.modernbert_tokenizer import ( + ModernBertTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class ModernBertTokenizerTest(TestCase): + def setUp(self): + self.vocab = ["[CLS]", "[PAD]", "[SEP]", "air", "Ġair", "plane", "Ġat"] + self.vocab += ["port", "[MASK]", "[UNK]"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.input_data = [ + "[CLS] airplane at airport[SEP][PAD]", + " airplane airport", + ] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=ModernBertTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]], + expected_detokenize_output=[ + "[CLS] airplane at airport[SEP][PAD]", + " airplane airport", + ], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + ModernBertTokenizer(vocabulary=["a", "b", "c"], merges=[]) diff --git a/tools/checkpoint_conversion/convert_modernbert_checkpoints.py b/tools/checkpoint_conversion/convert_modernbert_checkpoints.py new file mode 100644 index 0000000000..e6ff5e9599 --- /dev/null +++ b/tools/checkpoint_conversion/convert_modernbert_checkpoints.py @@ -0,0 +1,120 @@ +"""Convert ModernBERT checkpoints. 
+
+python tools/checkpoint_conversion/convert_modernbert_checkpoints.py \
+    --preset modernbert_base
+python tools/checkpoint_conversion/convert_modernbert_checkpoints.py \
+    --preset modernbert_large
+"""
+
+import json
+import os
+
+import numpy as np
+import requests
+import transformers
+from absl import app
+from absl import flags
+
+from keras_hub.src.models.modernbert.modernbert_backbone import (
+    ModernBertBackbone,
+)
+
+PRESET_MAP = {
+    "modernbert_base": "answerdotai/ModernBERT-base",
+    "modernbert_large": "answerdotai/ModernBERT-large",
+}
+
+EXTRACT_DIR = "./{}"
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "preset",
+    None,
+    f"Must be one of {','.join(PRESET_MAP.keys())}",
+)
+
+
+def download_files(hf_model_name):
+    extract_dir = EXTRACT_DIR.format(FLAGS.preset)
+    if not os.path.exists(extract_dir):
+        os.makedirs(extract_dir)
+
+    # Config.
+    config_path = os.path.join(extract_dir, "config.json")
+    response = requests.get(
+        f"https://huggingface.co/{hf_model_name}/raw/main/config.json"
+    )
+    with open(config_path, "wb") as config_file:
+        config_file.write(response.content)
+
+
+def convert_model(hf_model):
+    extract_dir = EXTRACT_DIR.format(FLAGS.preset)
+    config_path = os.path.join(extract_dir, "config.json")
+
+    # Build config.
+    cfg = {}
+    with open(config_path, "r") as pt_cfg_handler:
+        pt_cfg = json.load(pt_cfg_handler)
+    cfg["vocabulary_size"] = pt_cfg["vocab_size"]
+    cfg["num_layers"] = pt_cfg["num_hidden_layers"]
+    cfg["num_heads"] = pt_cfg["num_attention_heads"]
+    cfg["hidden_dim"] = pt_cfg["hidden_size"]
+    cfg["intermediate_dim"] = pt_cfg["intermediate_size"]
+    cfg["dropout"] = pt_cfg["embedding_dropout"]
+    # `local_attention` in the HF config is the total sliding window size.
+    cfg["local_attention_window"] = pt_cfg["local_attention"]
+    cfg["global_attn_every_n_layers"] = pt_cfg["global_attn_every_n_layers"]
+
+    return ModernBertBackbone(**cfg)
+
+
+def convert_weights(keras_model, hf_model):
+    # Get `state_dict` from `hf_model`.
+    state_dict = hf_model.state_dict()
+
+    keras_model.get_layer("token_embedding").set_weights(
+        [np.asarray(state_dict["embeddings.tok_embeddings.weight"])]
+    )
+    keras_model.get_layer("embeddings_layer_norm").set_weights(
+        [np.asarray(state_dict["embeddings.norm.weight"])]
+    )
+
+    for i in range(keras_model.num_layers):
+        layer = keras_model.transformer_layers[i]
+        # The first HF layer uses an identity `attn_norm`, so the key may
+        # be absent from the state dict.
+        attn_norm_key = f"layers.{i}.attn_norm.weight"
+        if attn_norm_key in state_dict:
+            layer.attn_norm.gamma.assign(
+                np.asarray(state_dict[attn_norm_key])
+            )
+        # PyTorch `nn.Linear` stores weights as `(out, in)`; Keras kernels
+        # are `(in, out)`, so every kernel below is transposed.
+        layer.attn.qkv.kernel.assign(
+            np.asarray(state_dict[f"layers.{i}.attn.Wqkv.weight"]).T
+        )
+        layer.attn.out_dense.kernel.assign(
+            np.asarray(state_dict[f"layers.{i}.attn.Wo.weight"]).T
+        )
+        layer.mlp_norm.gamma.assign(
+            np.asarray(state_dict[f"layers.{i}.mlp_norm.weight"])
+        )
+        # The HF MLP fuses the gated branches into a single `Wi`
+        # projection; split it into the two Keras kernels (`wi_0` is the
+        # activated branch, `wi_1` the gate).
+        wi = np.asarray(state_dict[f"layers.{i}.mlp.Wi.weight"]).T
+        intermediate_dim = wi.shape[-1] // 2
+        layer.mlp.wi_0.kernel.assign(wi[:, :intermediate_dim])
+        layer.mlp.wi_1.kernel.assign(wi[:, intermediate_dim:])
+        layer.mlp.wo.kernel.assign(
+            np.asarray(state_dict[f"layers.{i}.mlp.Wo.weight"]).T
+        )
+
+    keras_model.get_layer("final_norm").set_weights(
+        [np.asarray(state_dict["final_norm.weight"])]
+    )
+
+
+def main(_):
+    hf_model_name = PRESET_MAP[FLAGS.preset]
+    download_files(hf_model_name)
+
+    hf_model = transformers.AutoModel.from_pretrained(hf_model_name)
+    hf_model.eval()
+
+    print(f"🏃 Converting {FLAGS.preset}")
+    keras_model = convert_model(hf_model)
+    print("✅ KerasHub model loaded.")
+
+    convert_weights(keras_model, hf_model)
+    print("✅ Weights converted.")
+
+
+if __name__ == "__main__":
+    flags.mark_flag_as_required("preset")
+    app.run(main)
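+
+
+# Optional: a numerics spot-check comparing the converted backbone against
+# the HF reference. This helper is a sketch and is not wired into `main()`;
+# call it manually (e.g. from a REPL, or before `app.run`) if desired. The
+# token ids below are arbitrary and only need to be within the vocabulary.
+def validate_output(keras_model, hf_model):
+    import torch
+    from keras import ops
+
+    token_ids = np.array([[101, 7, 42, 8, 9, 0]], dtype="int32")
+    padding_mask = np.array([[1, 1, 1, 1, 1, 0]], dtype="int32")
+    keras_output = ops.convert_to_numpy(
+        keras_model({"token_ids": token_ids, "padding_mask": padding_mask})
+    )
+    with torch.no_grad():
+        hf_output = hf_model(
+            input_ids=torch.tensor(token_ids, dtype=torch.long),
+            attention_mask=torch.tensor(padding_mask, dtype=torch.long),
+        ).last_hidden_state
+    print("🔶 KerasHub output:", keras_output[0, 0, :5])
+    print("🔶 HF output:", hf_output.numpy()[0, 0, :5])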