Commit 1844b46

Fix the finetuning script's loss and metric config (#176)
* Fix the finetuning script
* change softmax layer to return logits
1 parent 4ba8729 · commit 1844b46

File tree

1 file changed: +8 −8 lines changed


examples/bert/run_glue_finetuning.py

Lines changed: 8 additions & 8 deletions
@@ -18,7 +18,7 @@
 import datasets
 import keras_tuner
 import tensorflow as tf
-import tensorflow_text as tftext
+import tensorflow_text as tf_text
 from absl import app
 from absl import flags
 from tensorflow import keras
@@ -81,20 +81,20 @@ def pack_inputs(
 ):
     # In case inputs weren't truncated (as they should have been),
     # fall back to some ad-hoc truncation.
-    trimmed_segments = tftext.RoundRobinTrimmer(
+    trimmed_segments = tf_text.RoundRobinTrimmer(
         seq_length - len(inputs) - 1
     ).trim(inputs)
     # Combine segments.
-    segments_combined, segment_ids = tftext.combine_segments(
+    segments_combined, segment_ids = tf_text.combine_segments(
         trimmed_segments,
         start_of_sequence_id=start_of_sequence_id,
         end_of_segment_id=end_of_segment_id,
     )
     # Pad to dense Tensors.
-    input_word_ids, _ = tftext.pad_model_inputs(
+    input_word_ids, _ = tf_text.pad_model_inputs(
         segments_combined, seq_length, pad_value=padding_id
     )
-    input_type_ids, input_mask = tftext.pad_model_inputs(
+    input_type_ids, input_mask = tf_text.pad_model_inputs(
         segment_ids, seq_length, pad_value=0
     )
     # Assemble nest of input tensors as expected by BERT model.
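For reference, the three tensorflow_text helpers renamed above can be exercised on toy inputs. This is a minimal sketch, not part of the commit: the token ids and seq_length are made up, and 101/102 are assumed [CLS]/[SEP] ids (the conventional BERT values), not ones taken from the script's vocab.

import tensorflow as tf
import tensorflow_text as tf_text

# Two toy tokenized segments (sentence A / sentence B) as RaggedTensors
# of token ids; the ids are arbitrary.
inputs = [
    tf.ragged.constant([[5, 6, 7, 8], [9, 10]]),
    tf.ragged.constant([[11, 12], [13, 14, 15]]),
]
seq_length = 8

# Reserve room for one [CLS] plus one [SEP] per segment, mirroring
# seq_length - len(inputs) - 1 in the script.
trimmed = tf_text.RoundRobinTrimmer(seq_length - len(inputs) - 1).trim(inputs)

# 101/102 are the conventional BERT [CLS]/[SEP] ids; adjust to your vocab.
combined, segment_ids = tf_text.combine_segments(
    trimmed, start_of_sequence_id=101, end_of_segment_id=102
)
input_word_ids, input_mask = tf_text.pad_model_inputs(
    combined, seq_length, pad_value=0
)
input_type_ids, _ = tf_text.pad_model_inputs(segment_ids, seq_length, pad_value=0)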
@@ -184,8 +184,8 @@ def build(self, hp):
         optimizer=keras.optimizers.Adam(
             learning_rate=hp.Choice("lr", [5e-5, 4e-5, 3e-5, 2e-5])
         ),
-        loss="sparse_categorical_crossentropy",
-        metrics=["accuracy"],
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=[keras.metrics.SparseCategoricalAccuracy()],
     )
     return finetuning_model
 
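This hunk is the substance of the commit: the string loss "sparse_categorical_crossentropy" expects the model to output probabilities, so once the classifier head returns raw logits (per the second commit bullet), the softmax has to move into the loss via from_logits=True. A minimal sketch of the pattern, using a toy stand-in head rather than the script's actual BERT model:

from tensorflow import keras

# Toy stand-in for a classification head: the final Dense layer has no
# softmax activation, so the model outputs raw logits.
model = keras.Sequential(
    [
        keras.layers.Dense(64, activation="relu", input_shape=(128,)),
        keras.layers.Dense(3),  # logits, no activation
    ]
)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    # from_logits=True tells the loss to apply the softmax itself;
    # pairing a probability-based loss with logit outputs trains
    # poorly and silently.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # SparseCategoricalAccuracy is logits-safe as well, since argmax is
    # unchanged by softmax.
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)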

@@ -197,7 +197,7 @@ def main(_):
     with open(FLAGS.vocab_file, "r") as vocab_file:
         for line in vocab_file:
             vocab.append(line.strip())
-    tokenizer = tftext.BertTokenizer(
+    tokenizer = tf_text.BertTokenizer(
         FLAGS.vocab_file,
         lower_case=FLAGS.do_lower_case,
         token_out_type=tf.int32,
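As a usage sketch of the renamed tokenizer ("vocab.txt" below is a placeholder for a BERT wordpiece vocabulary file, not one shipped with the repo):

import tensorflow as tf
import tensorflow_text as tf_text

# "vocab.txt" stands in for FLAGS.vocab_file above.
tokenizer = tf_text.BertTokenizer(
    "vocab.txt", lower_case=True, token_out_type=tf.int32
)
# tokenize() returns a RaggedTensor shaped [batch, words, wordpieces];
# merging the last two axes gives one flat id sequence per example.
token_ids = tokenizer.tokenize(["The quick brown fox."]).merge_dims(-2, -1)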
