Commit 7b17bc1

Merge pull request #10688 from typhoonzero/add_nccl2_support_for_trainer
Support nccl2 dist train in trainer
2 parents: 1f8243b + a41a94f

File tree

1 file changed: +33 −0

python/paddle/fluid/trainer.py

Lines changed: 33 additions & 0 deletions
@@ -131,7 +131,40 @@ def __init__(self, train_func, optimizer, param_path=None, place=None):
         # load params from param_path into scope
         io.load_persistables(exe, dirname=param_path)
 
+    def _transpile_nccl2_dist(self):
+        # PADDLE_TRAINER_IPS
+        if "PADDLE_TRAINER_IPS" not in os.environ:
+            self.nccl_id_var = None
+        else:
+            self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+            port = os.getenv("PADDLE_PSERVER_PORT")
+            worker_ips = os.getenv("PADDLE_TRAINER_IPS")
+            worker_endpoints = []
+            for ip in worker_ips.split(","):
+                worker_endpoints.append(':'.join([ip, port]))
+            self.num_trainers = len(worker_endpoints)
+            current_endpoint = os.getenv("POD_IP") + ":" + port
+            worker_endpoints.remove(current_endpoint)
+            # TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
+            # in ParallelExecutor to start
+            # distributed training using NCCL2
+            self.nccl_id_var = self.startup_program.global_block().create_var(
+                name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
+            self.startup_program.global_block().append_op(
+                type="gen_nccl_id",
+                inputs={},
+                outputs={"NCCLID": self.nccl_id_var},
+                attrs={
+                    "endpoint": current_endpoint,
+                    "endpoint_list": worker_endpoints,
+                    "trainer_id": self.trainer_id
+                })
+
     def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
+        self._transpile_nccl2_dist()
+        if self.nccl_id_var != None:
+            return
+
         if "PADDLE_TRAINING_ROLE" not in os.environ:
             return
 
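
To make the environment contract above concrete, here is a minimal standalone sketch of the endpoint bookkeeping that _transpile_nccl2_dist performs. The environment variable names (PADDLE_TRAINER_IPS, PADDLE_TRAINER_ID, PADDLE_PSERVER_PORT, POD_IP) come from the patch itself; the IP addresses and port number below are hypothetical values chosen only for illustration.

import os

# Hypothetical two-trainer job; these values are illustrative, not from the patch.
os.environ["PADDLE_TRAINER_IPS"] = "192.168.1.10,192.168.1.11"
os.environ["PADDLE_TRAINER_ID"] = "0"        # this process is trainer 0
os.environ["PADDLE_PSERVER_PORT"] = "6170"   # hypothetical port
os.environ["POD_IP"] = "192.168.1.10"        # this node's own IP

trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
port = os.getenv("PADDLE_PSERVER_PORT")

# Build "ip:port" endpoints for every trainer, as the patch does.
worker_endpoints = [
    ":".join([ip, port]) for ip in os.getenv("PADDLE_TRAINER_IPS").split(",")
]
num_trainers = len(worker_endpoints)

# Drop this node's own endpoint, leaving only the peers that the
# gen_nccl_id op must exchange the NCCL unique id with.
current_endpoint = os.getenv("POD_IP") + ":" + port
worker_endpoints.remove(current_endpoint)

print(num_trainers)       # 2
print(current_endpoint)   # 192.168.1.10:6170
print(worker_endpoints)   # ['192.168.1.11:6170']

With these values computed, the appended gen_nccl_id op can distribute the NCCL unique id held in the raw NCCLID variable between current_endpoint and the remaining worker_endpoints; as the TODO in the patch notes, self.nccl_id_var, self.num_trainers, and self.trainer_id are intended to be consumed by ParallelExecutor to launch NCCL2-based distributed training.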
