Skip to content

Commit 0fad63a

Browse files
seiriosPlus authored and guru4elephant committed
fix paddle cloud role maker bug (#18269) (#18311)
* fix paddle cloud role maker bug
1 parent f643260 commit 0fad63a

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

python/paddle/fluid/incubate/fleet/base/role_maker.py

Lines changed: 31 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919
'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
2020
]
2121

22+
import os
23+
2224

2325
class Role:
2426
WORKER = 1
@@ -295,45 +297,62 @@ def generate_role(self):
295297
class PaddleCloudRoleMaker(RoleMakerBase):
296298
def __init__(self):
297299
super(PaddleCloudRoleMaker, self).__init__()
300+
self._role_is_generated = False
298301

299302
def generate_role(self):
300303
if not self._role_is_generated:
301304
self.port = os.getenv("PADDLE_PORT", "6174")
302305
self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
303306
eplist = []
304-
for ip in pserver_ips.split(","):
305-
eplist.append(':'.join([ip, port]))
306-
self.endpoints = ",".join(eplist)
307-
self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
308-
self.current_endpoint = os.getenv("POD_IP",
309-
"localhost") + ":" + port
310-
self.role = os.getenv("TRAINING_ROLE", "TRAINER")
311-
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
307+
for ip in self.pserver_ips.split(","):
308+
eplist.append(':'.join([ip, self.port]))
309+
self.endpoints = ",".join(eplist)
310+
self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
311+
self.current_endpoint = os.getenv("POD_IP",
312+
"localhost") + ":" + self.port
313+
self.role = os.getenv("TRAINING_ROLE", "TRAINER")
314+
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
312315
self.eplist = eplist
316+
print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints)
313317
self.endpoints = self.endpoints.split(",")
318+
self._server_endpoints = self.endpoints
314319
if self.role.upper() == "PSERVER":
315-
self.current_id = self.endpoints.index(self.current_endpoint)
320+
self._current_id = self.endpoints.index(self.current_endpoint)
321+
self._role = Role.SERVER
316322
else:
317-
self.current_id = self.trainer_id
323+
self._current_id = self.trainer_id
324+
self._role = Role.WORKER
318325
self._role_is_generated = True
319326

320-
def is_wokrer(self):
327+
def is_worker(self):
328+
if not self._role_is_generated:
329+
self.generate_role()
321330
return self._role == Role.WORKER
322331

323332
def is_server(self):
333+
if not self._role_is_generated:
334+
self.generate_role()
324335
return self._role == Role.SERVER
325336

326337
def is_first_worker(self):
338+
if not self._role_is_generated:
339+
self.generate_role()
327340
return self._role == Role.WORKER and self._current_id == 0
328341

329342
def worker_index(self):
343+
if not self._role_is_generated:
344+
self.generate_role()
330345
return self._current_id
331346

332347
def server_index(self):
348+
if not self._role_is_generated:
349+
self.generate_role()
333350
return self._current_id
334351

335352
def worker_num(self):
336-
return self._worker_num
353+
if not self._role_is_generated:
354+
self.generate_role()
355+
return self._trainers
337356

338357

339358
class UserDefinedRoleMaker(RoleMakerBase):

0 commit comments

Comments
 (0)