Skip to content

Commit 9c58b0b

Browse files
authored
fix colab tpu training (#1050)
1 parent affe19b commit 9c58b0b

File tree

5 files changed

+34
-7
lines changed

5 files changed

+34
-7
lines changed

efficientdet/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def __init__(self, name, _): # pylint: disable=super-init-not-called
376376

377377
def scalar(name, tensor, is_tpu=True):
378378
"""Stores a (name, Tensor) tuple in a custom collection."""
379-
logging.info('Adding scale summary {}'.format(Pair(name, tensor)))
379+
logging.info('Adding scalar summary {}'.format(Pair(name, tensor)))
380380
if is_tpu:
381381
tf.add_to_collection('scalar_summaries', Pair(name, tf.reduce_mean(tensor)))
382382
else:

efficientnetv2/datasets.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def _input_fn(self, batch_size, current_host, num_hosts):
531531
logging.info('use tfds: %s[%s]', self.cfg.tfds_name,
532532
self.cfg.splits[self.split]['tfds_split'])
533533
ds = tfds.load(
534-
self.cfg.tfds_name, split=self.cfg.splits[self.split]['tfds_split'])
534+
self.cfg.tfds_name, split=self.cfg.splits[self.split]['tfds_split'], try_gcs=self.cfg.try_gcs)
535535
ds = ds.shard(num_hosts, current_host)
536536
if self.is_training:
537537
if self.cache:
@@ -581,6 +581,21 @@ class FlowersInput(CIFAR10Input):
581581
)))
582582

583583

584+
class TFFlowersInput(CIFAR10Input):
  """TFFlowers input from tfds gcs.

  The tfds 'train' split is partitioned into disjoint slices: the first 70%
  is used for training and the remaining 30% for both minival and eval.
  """
  cfg = copy.deepcopy(CIFAR10Input.cfg)
  cfg.update(
      dict(
          num_classes=5,
          tfds_name='tf_flowers',
          try_gcs=True,
          splits=dict(
              train=dict(num_images=2569, tfds_split='train[:70%]'),
              # Held-out tail of the split. 'train[30%:]' would select the
              # last 70% instead, overlapping the training slice and not
              # matching num_images=1101.
              minival=dict(num_images=1101, tfds_split='train[70%:]'),
              eval=dict(num_images=1101, tfds_split='train[70%:]'),
          )))
597+
598+
584599
class CarsInput(CIFAR10Input):
585600
"""Car input from tfds."""
586601
cfg = copy.deepcopy(CIFAR10Input.cfg)
@@ -620,6 +635,7 @@ def get_dataset_class(ds_name):
620635
'cifar10': CIFAR10Input,
621636
'cifar100': CIFAR100Input,
622637
'flowers': FlowersInput,
638+
'tfflowers': TFFlowersInput,
623639
'cars': CarsInput,
624640
}[ds_name]
625641

@@ -730,6 +746,11 @@ class FlowersFt(Cifar10Ft):
730746
cfg = copy.deepcopy(Cifar10Ft.cfg)
731747
cfg.data.override(dict(ds_name='flowers'))
732748

749+
@ds_register
class TFFlowersFt(Cifar10Ft):
  """Finetune configs for the 'tfflowers' dataset."""
  # Start from a private copy of the Cifar10Ft config and only swap the
  # dataset name.
  cfg = copy.deepcopy(Cifar10Ft.cfg)
  cfg.data.override({'ds_name': 'tfflowers'})
733754

734755
@ds_register
735756
class CarsFt(Cifar10Ft):

efficientnetv2/hparams.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def lookup(name, prefix='effnet:') -> Any:
288288
multiclass=None,
289289
num_classes=1000,
290290
tfds_name=None,
291+
try_gcs=False,
291292
tfds_split=None,
292293
splits=dict(
293294
train=dict(

efficientnetv2/main_tf2.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,11 @@ def main(_) -> None:
214214
log_dir=FLAGS.model_dir, update_freq=100)
215215
rstr_callback = utils.ReuableBackupAndRestore(backup_dir=FLAGS.model_dir)
216216

217+
def filter_callbacks(callbacks):
  """Selects the callbacks to pass to Keras for the current run.

  Args:
    callbacks: list of Keras callback instances.

  Returns:
    `callbacks` unchanged unless `strategy` is 'tpu' and FLAGS.model_dir is
    not a 'gs://' path; in that case a new list holding only the
    tf.keras.callbacks.ModelCheckpoint instances.
  """
  if strategy != 'tpu' or FLAGS.model_dir.startswith('gs://'):
    return callbacks
  # NOTE(review): presumably the other callbacks cannot write to a local
  # model_dir when training on TPU -- confirm.
  return [
      cb for cb in callbacks
      if isinstance(cb, tf.keras.callbacks.ModelCheckpoint)
  ]
221+
217222
def get_dataset(training, image_size, config):
218223
"""A shared utility to get input dataset."""
219224
if training:
@@ -235,7 +240,7 @@ def get_dataset(training, image_size, config):
235240
validation_data=get_dataset(
236241
training=False, image_size=eval_size, config=config),
237242
validation_steps=num_eval_images // config.eval.batch_size,
238-
callbacks=[ckpt_callback, tb_callback, rstr_callback],
243+
callbacks=filter_callbacks([ckpt_callback, tb_callback, rstr_callback]),
239244
# don't log spam if running on tpus
240245
verbose=2 if strategy == 'tpu' else 1,
241246
)
@@ -245,7 +250,7 @@ def get_dataset(training, image_size, config):
245250
get_dataset(training=True, image_size=train_size, config=config),
246251
epochs=config.train.epochs,
247252
steps_per_epoch=steps_per_epoch,
248-
callbacks=[ckpt_callback, tb_callback, rstr_callback],
253+
callbacks=filter_callbacks([ckpt_callback, tb_callback, rstr_callback]),
249254
verbose=2 if strategy == 'tpu' else 1,
250255
)
251256
else:
@@ -274,7 +279,7 @@ def get_dataset(training, image_size, config):
274279
initial_epoch=start_epoch,
275280
epochs=end_epoch,
276281
steps_per_epoch=steps_per_epoch,
277-
callbacks=[ckpt_callback, tb_callback, rstr_callback],
282+
callbacks=filter_callbacks([ckpt_callback, tb_callback, rstr_callback]),
278283
verbose=2 if strategy == 'tpu' else 1,
279284
)
280285
elif FLAGS.mode == 'eval':
@@ -285,7 +290,7 @@ def get_dataset(training, image_size, config):
285290
get_dataset(training=False, image_size=eval_size, config=config),
286291
batch_size=config.eval.batch_size,
287292
steps=num_eval_images // config.eval.batch_size,
288-
callbacks=[tb_callback, rstr_callback],
293+
callbacks=filter_callbacks([tb_callback, rstr_callback]),
289294
verbose=2 if strategy == 'tpu' else 1,
290295
)
291296

efficientnetv2/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def __init__(self, name, _): # pylint: disable=super-init-not-called
336336

337337
def scalar(name, tensor, is_tpu=True):
338338
"""Stores a (name, Tensor) tuple in a custom collection."""
339-
logging.info('Adding scale summary %s', Pair(name, tensor))
339+
logging.info('Adding scalar summary %s', Pair(name, tensor))
340340
if is_tpu:
341341
tf.compat.v1.add_to_collection('scalar_summaries',
342342
Pair(name, tf.reduce_mean(tensor)))

0 commit comments

Comments
 (0)