|
19 | 19 | 'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
|
20 | 20 | ]
|
21 | 21 |
|
| 22 | +import os |
| 23 | + |
22 | 24 |
|
23 | 25 | class Role:
|
24 | 26 | WORKER = 1
|
@@ -295,45 +297,62 @@ def generate_role(self):
|
295 | 297 | class PaddleCloudRoleMaker(RoleMakerBase):
|
296 | 298 | def __init__(self):
|
297 | 299 | super(PaddleCloudRoleMaker, self).__init__()
|
| 300 | + self._role_is_generated = False |
298 | 301 |
|
299 | 302 | def generate_role(self):
|
300 | 303 | if not self._role_is_generated:
|
301 | 304 | self.port = os.getenv("PADDLE_PORT", "6174")
|
302 | 305 | self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
|
303 | 306 | eplist = []
|
304 |
| - for ip in pserver_ips.split(","): |
305 |
| - eplist.append(':'.join([ip, port])) |
306 |
| - self.endpoints = ",".join(eplist) |
307 |
| - self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) |
308 |
| - self.current_endpoint = os.getenv("POD_IP", |
309 |
| - "localhost") + ":" + port |
310 |
| - self.role = os.getenv("TRAINING_ROLE", "TRAINER") |
311 |
| - self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) |
| 307 | + for ip in self.pserver_ips.split(","): |
| 308 | + eplist.append(':'.join([ip, self.port])) |
| 309 | + self.endpoints = ",".join(eplist) |
| 310 | + self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) |
| 311 | + self.current_endpoint = os.getenv("POD_IP", |
| 312 | + "localhost") + ":" + self.port |
| 313 | + self.role = os.getenv("TRAINING_ROLE", "TRAINER") |
| 314 | + self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) |
312 | 315 | self.eplist = eplist
|
| 316 | + print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints) |
313 | 317 | self.endpoints = self.endpoints.split(",")
|
| 318 | + self._server_endpoints = self.endpoints |
314 | 319 | if self.role.upper() == "PSERVER":
|
315 |
| - self.current_id = self.endpoints.index(self.current_endpoint) |
| 320 | + self._current_id = self.endpoints.index(self.current_endpoint) |
| 321 | + self._role = Role.SERVER |
316 | 322 | else:
|
317 |
| - self.current_id = self.trainer_id |
| 323 | + self._current_id = self.trainer_id |
| 324 | + self._role = Role.WORKER |
318 | 325 | self._role_is_generated = True
|
319 | 326 |
|
320 |
| - def is_wokrer(self): |
| 327 | + def is_worker(self): |
| 328 | + if not self._role_is_generated: |
| 329 | + self.generate_role() |
321 | 330 | return self._role == Role.WORKER
|
322 | 331 |
|
323 | 332 | def is_server(self):
|
| 333 | + if not self._role_is_generated: |
| 334 | + self.generate_role() |
324 | 335 | return self._role == Role.SERVER
|
325 | 336 |
|
326 | 337 | def is_first_worker(self):
|
| 338 | + if not self._role_is_generated: |
| 339 | + self.generate_role() |
327 | 340 | return self._role == Role.WORKER and self._current_id == 0
|
328 | 341 |
|
329 | 342 | def worker_index(self):
|
| 343 | + if not self._role_is_generated: |
| 344 | + self.generate_role() |
330 | 345 | return self._current_id
|
331 | 346 |
|
332 | 347 | def server_index(self):
|
| 348 | + if not self._role_is_generated: |
| 349 | + self.generate_role() |
333 | 350 | return self._current_id
|
334 | 351 |
|
335 | 352 | def worker_num(self):
|
336 |
| - return self._worker_num |
| 353 | + if not self._role_is_generated: |
| 354 | + self.generate_role() |
| 355 | + return self._trainers |
337 | 356 |
|
338 | 357 |
|
339 | 358 | class UserDefinedRoleMaker(RoleMakerBase):
|
|
0 commit comments