@@ -56,8 +56,10 @@ def define_flags():
       help='GRPC URL of the eval master. Set to an appropriate value when '
       'running on CPU/GPU')
   flags.DEFINE_string('eval_name', default=None, help='Eval job name')
-  flags.DEFINE_enum('strategy', '', ['tpu', 'gpus', ''],
+  flags.DEFINE_enum('strategy', '', ['tpu', 'gpus', 'multi-gpus', ''],
                     'Training: gpus for multi-gpu, if None, use TF default.')
+  flags.DEFINE_string('worker', default=None, help='Comma-separated worker addresses')
+  flags.DEFINE_integer('worker_index', default=0, help='Worker index')
 
   flags.DEFINE_integer(
       'num_cores', default=8, help='Number of TPU cores for training')
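Note on the two new flags: --worker carries the comma-separated list of worker addresses and --worker_index identifies the current process within that list. A minimal, hedged sketch of how they could be parsed in isolation with absl (flag names mirror the diff; the script name and addresses are illustrative):

```python
# Standalone sketch, not part of the PR: parse the new multi-worker flags
# with absl the same way the training script would.
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string('worker', default=None, help='Comma-separated worker addresses')
flags.DEFINE_integer('worker_index', default=0, help='Worker index')


def main(_):
  workers = FLAGS.worker.split(',') if FLAGS.worker else []
  if workers:
    print('worker %d of %d -> %s'
          % (FLAGS.worker_index, len(workers), workers[FLAGS.worker_index]))


if __name__ == '__main__':
  # e.g. python sketch.py --worker=host1:2222,host2:2222 --worker_index=1
  app.run(main)
```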
@@ -170,7 +172,8 @@ def main(_):
   if FLAGS.debug:
     tf.debugging.set_log_device_placement(True)
     logging.set_verbosity(logging.DEBUG)
-
+    tf.debugging.disable_traceback_filtering()
+
   if FLAGS.strategy == 'tpu':
     tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
         FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
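The debug branch gains tf.debugging.disable_traceback_filtering(), which stops TensorFlow from hiding its internal frames in error tracebacks. A hedged, self-contained sketch of what the three debug-mode toggles do on their own (assumes TF 2.7+ for the traceback call):

```python
# Illustrative only, not part of the PR: the three debug-mode switches.
import tensorflow as tf
from absl import logging

tf.debugging.set_log_device_placement(True)   # log which device each op runs on
logging.set_verbosity(logging.DEBUG)          # verbose absl logging
tf.debugging.disable_traceback_filtering()    # keep TF-internal frames in error tracebacks
```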
@@ -193,6 +196,16 @@ def main(_):
     ds_strategy = tf.distribute.MirroredStrategy(
         cross_device_ops=cross_device_ops)
     logging.info('All devices: %s', gpus)
+  elif FLAGS.strategy == 'multi-gpus':
+    import json
+    tf_config = {
+        'cluster': {
+            'worker': FLAGS.worker.split(',')
+        },
+        'task': {'type': 'worker', 'index': FLAGS.worker_index}
+    }
+    os.environ['TF_CONFIG'] = json.dumps(tf_config)
+    ds_strategy = tf.distribute.MultiWorkerMirroredStrategy()
   else:
     if tf.config.list_physical_devices('GPU'):
       ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
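For the new multi-gpus branch, each process must see a TF_CONFIG whose cluster lists every worker and whose task index is unique to that process, and the variable has to be set before MultiWorkerMirroredStrategy is constructed. A hedged sketch of the value for a hypothetical two-worker run (addresses are made up):

```python
# Illustrative only: TF_CONFIG for a two-worker run launched with
# --worker=10.0.0.1:2222,10.0.0.2:2222 and --worker_index=0 on the first
# host (the second host would set 'index': 1).
import json
import os

workers = '10.0.0.1:2222,10.0.0.2:2222'.split(',')
tf_config = {
    'cluster': {'worker': workers},
    'task': {'type': 'worker', 'index': 0},
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)
# tf.distribute.MultiWorkerMirroredStrategy() reads TF_CONFIG when it is
# constructed, so the script sets the variable immediately before building it.
```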
@@ -259,14 +272,19 @@ def get_dataset(is_training, config):
         ckpt_path,
         config.moving_average_decay,
         exclude_layers=['class_net', 'optimizer', 'box_net'])
+
   init_experimental(config)
   if 'train' in FLAGS.mode:
     val_dataset = get_dataset(False, config) if 'eval' in FLAGS.mode else None
+    if FLAGS.strategy == 'multi-gpus':
+      initial_epoch = 0
+    else:
+      initial_epoch = model.optimizer.iterations.numpy() // steps_per_epoch
     model.fit(
         get_dataset(True, config),
         epochs=config.num_epochs,
         steps_per_epoch=steps_per_epoch,
-        initial_epoch=model.optimizer.iterations.numpy() // steps_per_epoch,
+        initial_epoch=initial_epoch,
         callbacks=train_lib.get_callbacks(config.as_dict(), val_dataset),
         validation_data=val_dataset,
         validation_steps=(FLAGS.eval_samples // FLAGS.batch_size))
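The single-worker path keeps deriving initial_epoch from the restored optimizer state, while the multi-gpus branch pins it to 0 as shown above. A small hedged illustration of the resume arithmetic (numbers are made up):

```python
# Illustrative only: optimizer.iterations counts completed steps, so integer
# division by steps_per_epoch recovers the epoch to resume from.
steps_per_epoch = 1000
restored_iterations = 2500            # e.g. optimizer.iterations after restore_ckpt
initial_epoch = restored_iterations // steps_per_epoch
print(initial_epoch)                  # 2 -> fit() resumes at the third epoch
```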