diff --git a/examples/nlp/masked_language_modeling.py b/examples/nlp/masked_language_modeling.py
index 6701f0745e..2b5816cbae 100644
--- a/examples/nlp/masked_language_modeling.py
+++ b/examples/nlp/masked_language_modeling.py
@@ -1,11 +1,15 @@
"""
-Title: End-to-end Masked Language Modeling with BERT
-Author: [Ankur Singh](https://twitter.com/ankur310794)
-Date created: 2020/09/18
-Last modified: 2024/03/15
-Description: Implement a Masked Language Model (MLM) with BERT and fine-tune it on the IMDB Reviews dataset.
-Accelerator: GPU
-Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+# End-to-end Masked Language Modeling with BERT
+
+**Author:** [Ankur Singh](https://twitter.com/ankur310794)
+**Date created:** 2020/09/18
+**Last modified:** 2024/05/05
+**Description:** Implement a Masked Language Model (MLM) with BERT and fine-tune it on
+the IMDB Reviews dataset.
+**Accelerator:** GPU
+**Converted to Keras 3 by:** [Sitam Meur](https://github.com/sitamgithub-MSIT)
+**Converted to Keras 3 Backend-Agnostic by:** [Mrutyunjay Biswal](https://twitter.com/LearnStochastic)
"""
"""
@@ -32,35 +36,43 @@
train it with the masked language modeling task,
and then fine-tune this model on a sentiment classification task.
-We will use the Keras `TextVectorization` and `MultiHeadAttention` layers
+We will use the Keras `TextVectorization` and `MultiHeadAttention` layers, and the
+`PositionEmbedding` layer from `keras-nlp`
to create a BERT Transformer-Encoder network architecture.
-Note: This example should be run with `tf-nightly`.
+Note: This example is backend-agnostic. Set the Keras backend to `"tensorflow"`,
+`"torch"`, or `"jax"` as shown in the code below; no other code changes are needed.
"""
"""
## Setup
-
-Install `tf-nightly` via `pip install tf-nightly`.
"""
+# Install Keras 3 and KerasNLP (uncomment to run in a notebook):
+# !pip install --upgrade keras keras-nlp
+
import os
-os.environ["KERAS_BACKEND"] = "tensorflow"
-import keras_nlp
-import keras
-import tensorflow as tf
-from keras import layers
-from keras.layers import TextVectorization
-from dataclasses import dataclass
-import pandas as pd
-import numpy as np
-import glob
+# Set the Keras backend: one of "tensorflow", "jax", or "torch"
+os.environ["KERAS_BACKEND"] = "jax"
+
import re
+import glob
+import numpy as np
+import pandas as pd
+from pathlib import Path
from pprint import pprint
+from dataclasses import dataclass
+
+import keras
+from keras import ops
+from keras import layers
+
+import keras_nlp
+import tensorflow as tf
"""
-## Set-up Configuration
+## Configuration
"""
@@ -74,19 +86,43 @@ class Config:
NUM_HEAD = 8 # used in bert model
FF_DIM = 128 # used in bert model
NUM_LAYERS = 1
+ NUM_EPOCHS = 1
+ STEPS_PER_EPOCH = 2
config = Config()
"""
-## Load the data
+## Download the Data: IMDB Movie Review Sentiment Classification
+Download the IMDB data and load it into a Pandas DataFrame.
+"""
+
+fpath = keras.utils.get_file(
+ origin="https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
+)
+dirpath = Path(fpath).parent.absolute()
+_ = os.system(f"tar -xf {fpath} -C {dirpath}")
+
+"""
+The `aclImdb` folder contains a `train` and `test` subfolder:
+"""
+
+_ = os.system(f"ls {dirpath}/aclImdb")
+_ = os.system(f"ls {dirpath}/aclImdb/train")
+_ = os.system(f"ls {dirpath}/aclImdb/test")
-We will first download the IMDB data and load into a Pandas dataframe.
"""
+We are only interested in the `pos` and `neg` subfolders, so let's delete the rest:
+"""
+
+_ = os.system(f"rm -r {dirpath}/aclImdb/train/unsup")
+_ = os.system(f"rm -r {dirpath}/aclImdb/train/*.feat")
+_ = os.system(f"rm -r {dirpath}/aclImdb/train/*.txt")
+_ = os.system(f"rm -r {dirpath}/aclImdb/test/*.feat")
+_ = os.system(f"rm -r {dirpath}/aclImdb/test/*.txt")
-"""shell
-curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
-tar -xf aclImdb_v1.tar.gz
+"""
+Let's read the dataset from the text files into a DataFrame.
"""
@@ -100,9 +136,10 @@ def get_text_list_from_files(files):
def get_data_from_text_files(folder_name):
- pos_files = glob.glob("aclImdb/" + folder_name + "/pos/*.txt")
+
+ pos_files = glob.glob(f"{dirpath}/aclImdb/" + folder_name + "/pos/*.txt")
pos_texts = get_text_list_from_files(pos_files)
- neg_files = glob.glob("aclImdb/" + folder_name + "/neg/*.txt")
+ neg_files = glob.glob(f"{dirpath}/aclImdb/" + folder_name + "/neg/*.txt")
neg_texts = get_text_list_from_files(neg_files)
df = pd.DataFrame(
{
@@ -117,7 +154,8 @@ def get_data_from_text_files(folder_name):
train_df = get_data_from_text_files("train")
test_df = get_data_from_text_files("test")
-all_data = train_df.append(test_df)
+all_data = pd.concat([train_df, test_df], axis=0).reset_index(drop=True)
+assert len(all_data) != 0, "all_data is empty"
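+
+# Quick, purely illustrative sanity check on the combined DataFrame
+print(f"Total reviews: {len(all_data)}")
+print(all_data.head())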
"""
## Dataset preparation
@@ -125,7 +163,8 @@ def get_data_from_text_files(folder_name):
We will use the `TextVectorization` layer to vectorize the text into integer token ids.
It transforms a batch of strings into either
a sequence of token indices (one sample = 1D array of integer token indices, in order)
-or a dense representation (one sample = 1D array of float values encoding an unordered set of tokens).
+or a dense representation (one sample = 1D array of float values encoding an unordered
+set of tokens).
Below, we define 3 preprocessing functions.
@@ -156,7 +195,7 @@ def get_vectorize_layer(texts, vocab_size, max_seq, special_tokens=["[MASK]"]):
Returns:
layers.Layer: Return TextVectorization Keras Layer
"""
- vectorize_layer = TextVectorization(
+ vectorize_layer = layers.TextVectorization(
max_tokens=vocab_size,
output_mode="int",
standardize=custom_standardization,
@@ -179,14 +218,15 @@ def get_vectorize_layer(texts, vocab_size, max_seq, special_tokens=["[MASK]"]):
)
# Get mask token id for masked language model
-mask_token_id = vectorize_layer(["[mask]"]).numpy()[0][0]
+mask_token_id = ops.convert_to_numpy(vectorize_layer(["[mask]"])[0][0])
def encode(texts):
encoded_texts = vectorize_layer(texts)
- return encoded_texts.numpy()
+ return ops.convert_to_numpy(encoded_texts)
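+
+# Illustrative usage: `encode` pads/truncates each review to `config.MAX_LEN` token ids
+print(encode(["this movie was awesome"]).shape)  # (1, config.MAX_LEN)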
+# Note: the masking below is plain NumPy, run once as a data-preparation step,
+# so it works unchanged with every Keras backend.
def get_masked_input_and_labels(encoded_texts):
# 15% BERT masking
inp_mask = np.random.rand(*encoded_texts.shape) < 0.15
@@ -213,7 +253,7 @@ def get_masked_input_and_labels(encoded_texts):
)
# Prepare sample_weights to pass to .fit() method
- sample_weights = np.ones(labels.shape)
+ sample_weights = np.ones(labels.shape, dtype="float32")
sample_weights[labels == -1] = 0
# y_labels would be same as encoded_texts i.e input tokens
@@ -261,104 +301,105 @@ def get_masked_input_and_labels(encoded_texts):
using the `MultiHeadAttention` layer.
It will take token ids as inputs (including masked tokens)
and it will predict the correct ids for the masked input tokens.
+
+We will subclass `keras.layers.Layer` and `keras.Model` to define the
+BERT encoder layer and the MLM model.
"""
+class BertEncoderLayer(layers.Layer):
+ def __init__(self, layer_num, **kwargs):
+ super().__init__(**kwargs)
+ self.layer_num = layer_num
-def bert_module(query, key, value, i):
- # Multi headed self-attention
- attention_output = layers.MultiHeadAttention(
- num_heads=config.NUM_HEAD,
- key_dim=config.EMBED_DIM // config.NUM_HEAD,
- name="encoder_{}_multiheadattention".format(i),
- )(query, key, value)
- attention_output = layers.Dropout(0.1, name="encoder_{}_att_dropout".format(i))(
- attention_output
- )
- attention_output = layers.LayerNormalization(
- epsilon=1e-6, name="encoder_{}_att_layernormalization".format(i)
- )(query + attention_output)
-
- # Feed-forward layer
- ffn = keras.Sequential(
- [
- layers.Dense(config.FF_DIM, activation="relu"),
- layers.Dense(config.EMBED_DIM),
- ],
- name="encoder_{}_ffn".format(i),
- )
- ffn_output = ffn(attention_output)
- ffn_output = layers.Dropout(0.1, name="encoder_{}_ffn_dropout".format(i))(
- ffn_output
- )
- sequence_output = layers.LayerNormalization(
- epsilon=1e-6, name="encoder_{}_ffn_layernormalization".format(i)
- )(attention_output + ffn_output)
- return sequence_output
+        self.multi_head_attention = layers.MultiHeadAttention(
+            num_heads=config.NUM_HEAD,
+            key_dim=config.EMBED_DIM // config.NUM_HEAD,
+            name=f"encoder_{self.layer_num}_multiheadattention",
+        )
+        self.multi_head_attention_dropout = layers.Dropout(
+            0.1, name=f"encoder_{self.layer_num}_attn_dropout"
+        )
-loss_fn = keras.losses.SparseCategoricalCrossentropy(reduction=None)
-loss_tracker = keras.metrics.Mean(name="loss")
+ self.multi_head_attention_norm = layers.LayerNormalization(
+ epsilon=1e-6, name=f"encoder_{self.layer_num}_attn_layernorm"
+ )
+
+        self.ffn = keras.Sequential(
+            [
+                layers.Dense(config.FF_DIM, activation="relu"),
+                layers.Dense(config.EMBED_DIM),
+            ],
+            name=f"encoder_{self.layer_num}_ffn",
+        )
+
+ self.ffn_dropout = layers.Dropout(
+ 0.1, name=f"encoder_{self.layer_num}_ffn_dropout"
+ )
+
+ self.ffn_layernorm = layers.LayerNormalization(
+ epsilon=1e-6, name=f"encoder_{self.layer_num}_ffn_layernorm"
+ )
+ def call(self, inputs, training=False):
+ query, key, value = inputs
+ attn_output = self.multi_head_attention(query, key, value)
+ attn_output = self.multi_head_attention_dropout(attn_output)
+ attn_output = self.multi_head_attention_norm(query + attn_output)
+
+ ffn_output = self.ffn(attn_output)
+ ffn_output = self.ffn_dropout(ffn_output)
+
+ sequence_output = self.ffn_layernorm(attn_output + ffn_output)
+
+ return sequence_output
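+
+    def get_config(self):
+        # Include `layer_num` in the config so the layer can be re-created
+        # when the saved model is loaded back from disk.
+        config = super().get_config()
+        config.update({"layer_num": self.layer_num})
+        return config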
class MaskedLanguageModel(keras.Model):
- def train_step(self, inputs):
- if len(inputs) == 3:
- features, labels, sample_weight = inputs
- else:
- features, labels = inputs
- sample_weight = None
-
- with tf.GradientTape() as tape:
- predictions = self(features, training=True)
- loss = loss_fn(labels, predictions, sample_weight=sample_weight)
-
- # Compute gradients
- trainable_vars = self.trainable_variables
- gradients = tape.gradient(loss, trainable_vars)
-
- # Update weights
- self.optimizer.apply_gradients(zip(gradients, trainable_vars))
-
- # Compute our own metrics
- loss_tracker.update_state(loss, sample_weight=sample_weight)
-
- # Return a dict mapping metric names to current value
- return {"loss": loss_tracker.result()}
-
- @property
- def metrics(self):
- # We list our `Metric` objects here so that `reset_states()` can be
- # called automatically at the start of each epoch
- # or at the start of `evaluate()`.
- # If you don't implement this property, you have to call
- # `reset_states()` yourself at the time of your choosing.
- return [loss_tracker]
-
-
-def create_masked_language_bert_model():
- inputs = layers.Input((config.MAX_LEN,), dtype="int64")
-
- word_embeddings = layers.Embedding(
- config.VOCAB_SIZE, config.EMBED_DIM, name="word_embedding"
- )(inputs)
- position_embeddings = keras_nlp.layers.PositionEmbedding(
- sequence_length=config.MAX_LEN
- )(word_embeddings)
- embeddings = word_embeddings + position_embeddings
-
- encoder_output = embeddings
- for i in range(config.NUM_LAYERS):
- encoder_output = bert_module(encoder_output, encoder_output, encoder_output, i)
-
- mlm_output = layers.Dense(config.VOCAB_SIZE, name="mlm_cls", activation="softmax")(
- encoder_output
- )
- mlm_model = MaskedLanguageModel(inputs, mlm_output, name="masked_bert_model")
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
- optimizer = keras.optimizers.Adam(learning_rate=config.LR)
- mlm_model.compile(optimizer=optimizer)
- return mlm_model
+ self.word_embeddings = layers.Embedding(
+ config.VOCAB_SIZE, config.EMBED_DIM, name="word_embedding"
+ )
+ self.position_embeddings = keras_nlp.layers.PositionEmbedding(
+ sequence_length=config.MAX_LEN
+ )
+
+ self.bert_encoder_layers = [
+ BertEncoderLayer(n_layer) for n_layer in range(config.NUM_LAYERS)
+ ]
+
+        self.mlm_output = layers.Dense(
+            config.VOCAB_SIZE, activation="softmax", name="mlm_cls"
+        )
+
+ def call(self, inputs, training=False):
+ word_embeddings = self.word_embeddings(inputs)
+ position_embeddings = self.position_embeddings(word_embeddings)
+ encoder_output = word_embeddings + position_embeddings
+ for bert_encoder_layer in self.bert_encoder_layers:
+            encoder_output = bert_encoder_layer(
+                [encoder_output, encoder_output, encoder_output]
+            )
+
+ return self.mlm_output(encoder_output)
+
+ def get_config(self):
+ return super().get_config()
+
+# Reset Keras backend session
+keras.backend.clear_session()
+
+# Define model and compile
+optimizer = keras.optimizers.Adam(learning_rate=config.LR)
+loss_fn = keras.losses.SparseCategoricalCrossentropy(reduction=None)
+
+bert_masked_model = MaskedLanguageModel()
+bert_masked_model.compile(optimizer=optimizer, loss=loss_fn)
+
+# Build the model by calling it on dummy token ids, then show the model summary
+_ = bert_masked_model(np.zeros((1, config.MAX_LEN), dtype="int32"))
+bert_masked_model.summary()
+
+"""
+## Define a Callback to Generate Masked Token Predictions
+"""
id2token = dict(enumerate(vectorize_layer.get_vocabulary()))
token2id = {y: x for x, y in id2token.items()}
@@ -366,7 +407,7 @@ def create_masked_language_bert_model():
class MaskedTextGenerator(keras.callbacks.Callback):
def __init__(self, sample_tokens, top_k=5):
- self.sample_tokens = sample_tokens
+ self.sample_tokens = ops.convert_to_numpy(sample_tokens)
self.k = top_k
def decode(self, tokens):
@@ -378,20 +419,20 @@ def convert_ids_to_tokens(self, id):
def on_epoch_end(self, epoch, logs=None):
prediction = self.model.predict(self.sample_tokens)
- masked_index = np.where(self.sample_tokens == mask_token_id)
- masked_index = masked_index[1]
+        # `self.sample_tokens` is a NumPy array, so NumPy indexing keeps this
+        # backend-agnostic; take the position of the masked token.
+        masked_index = np.where(self.sample_tokens == mask_token_id)[1][0]
mask_prediction = prediction[0][masked_index]
- top_indices = mask_prediction[0].argsort()[-self.k :][::-1]
- values = mask_prediction[0][top_indices]
+ top_indices = mask_prediction.argsort()[-self.k :][::-1]
+ values = mask_prediction[top_indices]
for i in range(len(top_indices)):
p = top_indices[i]
v = values[i]
- tokens = np.copy(sample_tokens[0])
- tokens[masked_index[0]] = p
+ tokens = np.copy(self.sample_tokens[0])
+ tokens[masked_index] = p
result = {
- "input_text": self.decode(sample_tokens[0].numpy()),
+ "input_text": self.decode(self.sample_tokens[0]),
"prediction": self.decode(tokens),
"probability": v,
"predicted mask token": self.convert_ids_to_tokens(p),
@@ -400,25 +441,28 @@ def on_epoch_end(self, epoch, logs=None):
sample_tokens = vectorize_layer(["I have watched this [mask] and it was awesome"])
-generator_callback = MaskedTextGenerator(sample_tokens.numpy())
-
-bert_masked_model = create_masked_language_bert_model()
-bert_masked_model.summary()
+generator_callback = MaskedTextGenerator(sample_tokens)
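+
+# Illustrative: decode the sample back to text to confirm where the [mask] sits
+print(generator_callback.decode(generator_callback.sample_tokens[0]))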
"""
## Train and Save
"""
-bert_masked_model.fit(mlm_ds, epochs=5, callbacks=[generator_callback])
+bert_masked_model.fit(
+ mlm_ds,
+ epochs=config.NUM_EPOCHS,
+ steps_per_epoch=config.STEPS_PER_EPOCH,
+ callbacks=[generator_callback],
+)
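+
+# Note: `NUM_EPOCHS = 1` and `STEPS_PER_EPOCH = 2` keep this run short for
+# demonstration purposes; for a useful MLM, drop `steps_per_epoch` and train for
+# more epochs (the original example trained for 5).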
bert_masked_model.save("bert_mlm_imdb.keras")
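+
+# Note: `MaskedLanguageModel` and `BertEncoderLayer` are custom classes, so the
+# saved "bert_mlm_imdb.keras" file can only be reloaded when these classes are
+# importable or are passed to `keras.models.load_model` via `custom_objects`.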
"""
## Fine-tune a sentiment classification model
-We will fine-tune our self-supervised model on a downstream task of sentiment classification.
-To do this, let's create a classifier by adding a pooling layer and a `Dense` layer on top of the
+We will fine-tune our self-supervised model on a downstream task of sentiment
+classification.
+To do this, let's create a classifier by adding a pooling layer and a `Dense` layer
+on top of the
pretrained BERT features.
-
"""
# Load pretrained bert model
@@ -434,7 +478,7 @@ def on_epoch_end(self, epoch, logs=None):
def create_classifier_bert_model():
- inputs = layers.Input((config.MAX_LEN,), dtype="int64")
+ inputs = layers.Input((config.MAX_LEN,), dtype="int32")
sequence_output = pretrained_bert_model(inputs)
pooled_output = layers.GlobalMaxPooling1D()(sequence_output)
hidden_layer = layers.Dense(64, activation="relu")(pooled_output)
@@ -465,7 +509,7 @@ def create_classifier_bert_model():
)
classifer_model.fit(
train_classifier_ds,
- epochs=5,
+ epochs=config.NUM_EPOCHS,
validation_data=test_classifier_ds,
)
@@ -481,7 +525,7 @@ def create_classifier_bert_model():
def get_end_to_end(model):
- inputs_string = keras.Input(shape=(1,), dtype="string")
+ inputs_string = layers.Input(shape=(1,), dtype="string")
indices = vectorize_layer(inputs_string)
outputs = model(indices)
end_to_end_model = keras.Model(inputs_string, outputs, name="end_to_end_model")