@@ -29,7 +29,7 @@ def parse_config():
2929 parser .add_argument ('--pretrained_model' , type = str , default = None , help = 'pretrained_model' )
3030 parser .add_argument ('--launcher' , choices = ['none' , 'pytorch' , 'slurm' ], default = 'none' )
3131 parser .add_argument ('--tcp_port' , type = int , default = 18888 , help = 'tcp port for distrbuted training' )
32- parser .add_argument ('--local_rank' , type = int , default = 0 , help = 'local rank for distributed training' )
32+ parser .add_argument ('--local_rank' , type = int , default = None , help = 'local rank for distributed training' )
3333 parser .add_argument ('--set' , dest = 'set_cfgs' , default = None , nargs = argparse .REMAINDER ,
3434 help = 'set extra config keys if needed' )
3535
@@ -145,6 +145,9 @@ def main():
145145 dist_test = False
146146 total_gpus = 1
147147 else :
148+ if args .local_rank is None :
149+ args .local_rank = int (os .environ .get ('LOCAL_RANK' , '0' ))
150+
148151 total_gpus , cfg .LOCAL_RANK = getattr (common_utils , 'init_dist_%s' % args .launcher )(
149152 args .tcp_port , args .local_rank , backend = 'nccl'
150153 )
0 commit comments