@@ -131,7 +131,40 @@ def __init__(self, train_func, optimizer, param_path=None, place=None):
             # load params from param_path into scope
             io.load_persistables(exe, dirname=param_path)
 
+    def _transpile_nccl2_dist(self):
+        # PADDLE_TRAINER_IPS
+        if "PADDLE_TRAINER_IPS" not in os.environ:
+            self.nccl_id_var = None
+        else:
+            self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+            port = os.getenv("PADDLE_PSERVER_PORT")
+            worker_ips = os.getenv("PADDLE_TRAINER_IPS")
+            worker_endpoints = []
+            for ip in worker_ips.split(","):
+                worker_endpoints.append(':'.join([ip, port]))
+            self.num_trainers = len(worker_endpoints)
+            current_endpoint = os.getenv("POD_IP") + ":" + port
+            worker_endpoints.remove(current_endpoint)
+            # TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
+            # in ParallelExecutor to start
+            # distributed training using NCCL2
+            self.nccl_id_var = self.startup_program.global_block().create_var(
+                name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
+            self.startup_program.global_block().append_op(
+                type="gen_nccl_id",
+                inputs={},
+                outputs={"NCCLID": self.nccl_id_var},
+                attrs={
+                    "endpoint": current_endpoint,
+                    "endpoint_list": worker_endpoints,
+                    "trainer_id": self.trainer_id
+                })
+
     def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
+        self._transpile_nccl2_dist()
+        if self.nccl_id_var != None:
+            return
+
         if "PADDLE_TRAINING_ROLE" not in os.environ:
             return
 
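For context, a minimal sketch of the environment a launcher would have to provide for the new `_transpile_nccl2_dist` branch to run. The variable names are taken from the diff above; the IP addresses and port are illustrative values only, not part of this change:

```python
import os

# Hypothetical values; only the variable names come from the diff above.
os.environ["PADDLE_TRAINER_IPS"] = "192.168.1.10,192.168.1.11"  # IPs of all trainer nodes
os.environ["PADDLE_TRAINER_ID"] = "0"                           # index of this trainer
os.environ["PADDLE_PSERVER_PORT"] = "6174"                      # port shared by all trainers
os.environ["POD_IP"] = "192.168.1.10"                           # IP of the current node

# With these set, _transpile_nccl2_dist() computes num_trainers = 2, builds
# worker_endpoints = ["192.168.1.11:6174"] (the current endpoint is removed),
# creates the persistable RAW variable "NCCLID" in the startup program, and
# appends a gen_nccl_id op whose attrs carry the current endpoint, the remaining
# endpoints, and the trainer id. _dist_transpile_if_necessary() then returns
# early because self.nccl_id_var is set, skipping the pserver-style transpile.
```

If `PADDLE_TRAINER_IPS` is not set, `self.nccl_id_var` stays `None` and the existing `PADDLE_TRAINING_ROLE` path below runs unchanged.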