@@ -214,6 +214,11 @@ def main(_) -> None:
214214 log_dir = FLAGS .model_dir , update_freq = 100 )
215215 rstr_callback = utils .ReuableBackupAndRestore (backup_dir = FLAGS .model_dir )
216216
217+ def filter_callbacks (callbacks ):
218+ if strategy == 'tpu' and not FLAGS .model_dir .startswith ('gs://' ):
219+ return list (filter (lambda callback : isinstance (callback , tf .keras .callbacks .ModelCheckpoint ), callbacks ))
220+ return callbacks
221+
217222 def get_dataset (training , image_size , config ):
218223 """A shared utility to get input dataset."""
219224 if training :
@@ -235,7 +240,7 @@ def get_dataset(training, image_size, config):
235240 validation_data = get_dataset (
236241 training = False , image_size = eval_size , config = config ),
237242 validation_steps = num_eval_images // config .eval .batch_size ,
238- callbacks = [ckpt_callback , tb_callback , rstr_callback ],
243+ callbacks = filter_callbacks ( [ckpt_callback , tb_callback , rstr_callback ]) ,
239244 # don't log spam if running on tpus
240245 verbose = 2 if strategy == 'tpu' else 1 ,
241246 )
@@ -245,7 +250,7 @@ def get_dataset(training, image_size, config):
245250 get_dataset (training = True , image_size = train_size , config = config ),
246251 epochs = config .train .epochs ,
247252 steps_per_epoch = steps_per_epoch ,
248- callbacks = [ckpt_callback , tb_callback , rstr_callback ],
253+ callbacks = filter_callbacks ( [ckpt_callback , tb_callback , rstr_callback ]) ,
249254 verbose = 2 if strategy == 'tpu' else 1 ,
250255 )
251256 else :
@@ -274,7 +279,7 @@ def get_dataset(training, image_size, config):
274279 initial_epoch = start_epoch ,
275280 epochs = end_epoch ,
276281 steps_per_epoch = steps_per_epoch ,
277- callbacks = [ckpt_callback , tb_callback , rstr_callback ],
282+ callbacks = filter_callbacks ( [ckpt_callback , tb_callback , rstr_callback ]) ,
278283 verbose = 2 if strategy == 'tpu' else 1 ,
279284 )
280285 elif FLAGS .mode == 'eval' :
@@ -285,7 +290,7 @@ def get_dataset(training, image_size, config):
285290 get_dataset (training = False , image_size = eval_size , config = config ),
286291 batch_size = config .eval .batch_size ,
287292 steps = num_eval_images // config .eval .batch_size ,
288- callbacks = [tb_callback , rstr_callback ],
293+ callbacks = filter_callbacks ( [tb_callback , rstr_callback ]) ,
289294 verbose = 2 if strategy == 'tpu' else 1 ,
290295 )
291296
0 commit comments