1818import datasets
1919import keras_tuner
2020import tensorflow as tf
21- import tensorflow_text as tftext
21+ import tensorflow_text as tf_text
2222from absl import app
2323from absl import flags
2424from tensorflow import keras
@@ -81,20 +81,20 @@ def pack_inputs(
8181):
8282 # In case inputs weren't truncated (as they should have been),
8383 # fall back to some ad-hoc truncation.
84- trimmed_segments = tftext .RoundRobinTrimmer (
84+ trimmed_segments = tf_text .RoundRobinTrimmer (
8585 seq_length - len (inputs ) - 1
8686 ).trim (inputs )
8787 # Combine segments.
88- segments_combined , segment_ids = tftext .combine_segments (
88+ segments_combined , segment_ids = tf_text .combine_segments (
8989 trimmed_segments ,
9090 start_of_sequence_id = start_of_sequence_id ,
9191 end_of_segment_id = end_of_segment_id ,
9292 )
9393 # Pad to dense Tensors.
94- input_word_ids , _ = tftext .pad_model_inputs (
94+ input_word_ids , _ = tf_text .pad_model_inputs (
9595 segments_combined , seq_length , pad_value = padding_id
9696 )
97- input_type_ids , input_mask = tftext .pad_model_inputs (
97+ input_type_ids , input_mask = tf_text .pad_model_inputs (
9898 segment_ids , seq_length , pad_value = 0
9999 )
100100 # Assemble nest of input tensors as expected by BERT model.
@@ -184,8 +184,8 @@ def build(self, hp):
184184 optimizer = keras .optimizers .Adam (
185185 learning_rate = hp .Choice ("lr" , [5e-5 , 4e-5 , 3e-5 , 2e-5 ])
186186 ),
187- loss = "sparse_categorical_crossentropy" ,
188- metrics = ["accuracy" ],
187+ loss = keras . losses . SparseCategoricalCrossentropy ( from_logits = True ) ,
188+ metrics = [keras . metrics . SparseCategoricalAccuracy () ],
189189 )
190190 return finetuning_model
191191
@@ -197,7 +197,7 @@ def main(_):
197197 with open (FLAGS .vocab_file , "r" ) as vocab_file :
198198 for line in vocab_file :
199199 vocab .append (line .strip ())
200- tokenizer = tftext .BertTokenizer (
200+ tokenizer = tf_text .BertTokenizer (
201201 FLAGS .vocab_file ,
202202 lower_case = FLAGS .do_lower_case ,
203203 token_out_type = tf .int32 ,
0 commit comments