@@ -33,7 +33,7 @@ def parse_config():
3333 parser .add_argument ('--sync_bn' , action = 'store_true' , default = False , help = 'whether to use sync bn' )
3434 parser .add_argument ('--fix_random_seed' , action = 'store_true' , default = False , help = '' )
3535 parser .add_argument ('--ckpt_save_interval' , type = int , default = 1 , help = 'number of training epochs' )
36- parser .add_argument ('--local_rank' , type = int , default = 0 , help = 'local rank for distributed training' )
36+ parser .add_argument ('--local_rank' , type = int , default = None , help = 'local rank for distributed training' )
3737 parser .add_argument ('--max_ckpt_save_num' , type = int , default = 30 , help = 'max number of saved checkpoint' )
3838 parser .add_argument ('--merge_all_iters_to_one_epoch' , action = 'store_true' , default = False , help = '' )
3939 parser .add_argument ('--set' , dest = 'set_cfgs' , default = None , nargs = argparse .REMAINDER ,
@@ -71,6 +71,9 @@ def main():
7171 dist_train = False
7272 total_gpus = 1
7373 else :
74+ if args .local_rank is None :
75+ args .local_rank = int (os .environ .get ('LOCAL_RANK' , '0' ))
76+
7477 total_gpus , cfg .LOCAL_RANK = getattr (common_utils , 'init_dist_%s' % args .launcher )(
7578 args .tcp_port , args .local_rank , backend = 'nccl'
7679 )
0 commit comments