Skip to content

Commit 68eb2f7

Browse files
Added KerasTuner Hyper-Parameter Search for the BERT fine-tuning script. (#143)
* added keras-tuner * added keras-tuner and using the best model * changed build_model and a few other changes * ran format.sh and lint.sh * changed the function to a class and added 4e-5 * added comments and final checks * ran format.sh and lint.sh * made changes according to the reviews * ran format.sh and lint.sh * resolved comments, going with val_loss for now * ran format.sh and lint.sh * changed constants to flags, and ran format checks * removed retraining of the best model found * Remove no longer relevant comment Co-authored-by: Matt Watson <[email protected]>
1 parent fd91d24 commit 68eb2f7

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

examples/bert/run_glue_finetuning.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717

1818
import datasets
19+
import keras_tuner
1920
import tensorflow as tf
2021
import tensorflow_text as tftext
2122
from absl import app
@@ -68,8 +69,6 @@
6869

6970
flags.DEFINE_integer("epochs", 3, "The number of training epochs.")
7071

71-
flags.DEFINE_float("learning_rate", 2e-5, "The initial learning rate for Adam.")
72-
7372
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
7473

7574

@@ -167,9 +166,32 @@ def call(self, inputs):
167166
return self._logit_layer(outputs)
168167

169168

169+
class BertHyperModel(keras_tuner.HyperModel):
170+
"""Creates a hypermodel to help with the search space for finetuning."""
171+
172+
def __init__(self, bert_config):
173+
self.bert_config = bert_config
174+
175+
def build(self, hp):
176+
model = keras.models.load_model(FLAGS.saved_model_input, compile=False)
177+
bert_config = self.bert_config
178+
finetuning_model = BertClassificationFinetuner(
179+
bert_model=model,
180+
hidden_size=bert_config["hidden_size"],
181+
num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2,
182+
)
183+
finetuning_model.compile(
184+
optimizer=keras.optimizers.Adam(
185+
learning_rate=hp.Choice("lr", [5e-5, 4e-5, 3e-5, 2e-5])
186+
),
187+
loss="sparse_categorical_crossentropy",
188+
metrics=["accuracy"],
189+
)
190+
return finetuning_model
191+
192+
170193
def main(_):
171194
print(f"Reading input model from {FLAGS.saved_model_input}")
172-
model = keras.models.load_model(FLAGS.saved_model_input)
173195

174196
vocab = []
175197
with open(FLAGS.vocab_file, "r") as vocab_file:
@@ -200,6 +222,7 @@ def preprocess_data(inputs, labels):
200222

201223
# Read and preprocess GLUE task data.
202224
train_ds, test_ds, validation_ds = load_data(FLAGS.task_name)
225+
203226
train_ds = train_ds.batch(FLAGS.batch_size).map(
204227
preprocess_data, num_parallel_calls=tf.data.AUTOTUNE
205228
)
@@ -210,18 +233,27 @@ def preprocess_data(inputs, labels):
210233
preprocess_data, num_parallel_calls=tf.data.AUTOTUNE
211234
)
212235

213-
finetuning_model = BertClassificationFinetuner(
214-
bert_model=model,
215-
hidden_size=bert_config["hidden_size"],
216-
num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2,
217-
)
218-
finetuning_model.compile(
219-
optimizer=keras.optimizers.Adam(learning_rate=FLAGS.learning_rate),
220-
loss="sparse_categorical_crossentropy",
221-
metrics=["accuracy"],
236+
# Create a hypermodel object for a RandomSearch.
237+
hypermodel = BertHyperModel(bert_config)
238+
239+
# Initialize the random search over the 4 learning rate parameters, for 4
240+
# trials and 3 epochs for each trial.
241+
tuner = keras_tuner.RandomSearch(
242+
hypermodel=hypermodel,
243+
objective=keras_tuner.Objective("val_loss", direction="min"),
244+
max_trials=4,
245+
overwrite=True,
246+
project_name="hyperparameter_tuner_results",
222247
)
223-
finetuning_model.fit(
224-
train_ds, epochs=FLAGS.epochs, validation_data=validation_ds
248+
249+
tuner.search(train_ds, epochs=FLAGS.epochs, validation_data=validation_ds)
250+
251+
# Extract the best hyperparameters after the search.
252+
best_hp = tuner.get_best_hyperparameters()[0]
253+
finetuning_model = tuner.get_best_models()[0]
254+
255+
print(
256+
f"The best hyperparameters found are:\nLearning Rate: {best_hp['lr']}"
225257
)
226258

227259
if FLAGS.do_evaluation:

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"datasets", # For GLUE in BERT example.
5454
"nltk",
5555
"wikiextractor",
56+
"keras-tuner",
5657
],
5758
},
5859
classifiers=[

0 commit comments

Comments
 (0)