Skip to content

Commit f1239cf

Browse files
committed
add multi node training
1 parent 861975e commit f1239cf

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

efficientdet/tf2/train.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,10 @@ def define_flags():
5656
help='GRPC URL of the eval master. Set to an appropriate value when '
5757
'running on CPU/GPU')
5858
flags.DEFINE_string('eval_name', default=None, help='Eval job name')
59-
flags.DEFINE_enum('strategy', '', ['tpu', 'gpus', ''],
59+
flags.DEFINE_enum('strategy', '', ['tpu', 'gpus', 'multi-gpus', ''],
6060
'Training: gpus for multi-gpu, if None, use TF default.')
61+
flags.DEFINE_string('worker', default=None, help='Workers server address')
62+
flags.DEFINE_integer('worker_index', default=0, help='Worker index')
6163

6264
flags.DEFINE_integer(
6365
'num_cores', default=8, help='Number of TPU cores for training')
@@ -170,7 +172,8 @@ def main(_):
170172
if FLAGS.debug:
171173
tf.debugging.set_log_device_placement(True)
172174
logging.set_verbosity(logging.DEBUG)
173-
175+
tf.debugging.disable_traceback_filtering()
176+
174177
if FLAGS.strategy == 'tpu':
175178
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
176179
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
@@ -193,6 +196,16 @@ def main(_):
193196
ds_strategy = tf.distribute.MirroredStrategy(
194197
cross_device_ops=cross_device_ops)
195198
logging.info('All devices: %s', gpus)
199+
elif FLAGS.strategy == 'multi-gpus':
200+
import json
201+
tf_config = {
202+
'cluster': {
203+
'worker': FLAGS.worker.split(',')
204+
},
205+
'task': {'type': 'worker', 'index': FLAGS.worker_index}
206+
}
207+
os.environ['TF_CONFIG'] = json.dumps(tf_config)
208+
ds_strategy = tf.distribute.MultiWorkerMirroredStrategy()
196209
else:
197210
if tf.config.list_physical_devices('GPU'):
198211
ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
@@ -259,14 +272,19 @@ def get_dataset(is_training, config):
259272
ckpt_path,
260273
config.moving_average_decay,
261274
exclude_layers=['class_net', 'optimizer', 'box_net'])
275+
262276
init_experimental(config)
263277
if 'train' in FLAGS.mode:
264278
val_dataset = get_dataset(False, config) if 'eval' in FLAGS.mode else None
279+
if FLAGS.strategy == 'multi-gpus':
280+
initial_epoch = 0
281+
else:
282+
initial_epoch = model.optimizer.iterations.numpy() // steps_per_epoch
265283
model.fit(
266284
get_dataset(True, config),
267285
epochs=config.num_epochs,
268286
steps_per_epoch=steps_per_epoch,
269-
initial_epoch=model.optimizer.iterations.numpy() // steps_per_epoch,
287+
initial_epoch=initial_epoch,
270288
callbacks=train_lib.get_callbacks(config.as_dict(), val_dataset),
271289
validation_data=val_dataset,
272290
validation_steps=(FLAGS.eval_samples // FLAGS.batch_size))

efficientdet/tf2/tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@
369369
" !unzip annotations_trainval2017.zip\n",
370370
"\n",
371371
" !mkdir tfrecord\n",
372-
" !PYTHONPATH=\".:$PYTHONPATH\" python dataset/create_coco_tfrecord.py \\\n",
372+
" !python -m dataset.create_coco_tfrecord \\\n",
373373
" --image_dir=val2017 \\\n",
374374
" --caption_annotations_file=annotations/captions_val2017.json \\\n",
375375
" --output_file_prefix=tfrecord/val \\\n",

0 commit comments

Comments
 (0)