Skip to content

Commit 7be50f9

Browse files
author
lilong12
authored
update, test=develop (#33588)
1 parent 172f271 commit 7be50f9

File tree

1 file changed

+19
-45
lines changed

1 file changed

+19
-45
lines changed

python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -428,59 +428,33 @@ def _init_comm(self):
428428

429429
# pp ring
430430
if self.pp_degree > 1:
431-
if self.schedule_mode == 'F-then-B': # GPipe
432-
self._collective_helper._init_communicator(
433-
self._startup_program,
434-
self.current_endpoint,
435-
self.pp_group_endpoints,
436-
self.pp_rank,
437-
self.pp_ring_id,
438-
False,
439-
global_ring_id=self.global_ring_id,
440-
sync=False)
441-
# append_naive_sync(startup_block, self.startup_prog_sync_var,
442-
# self.global_ring_id)
431+
for pair in self.pipeline_pair:
432+
pair_key = pair[0] * 1000 + pair[1]
433+
ring_id = self.pp_ring_map[pair_key]
434+
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
435+
if self.pp_rank not in pair: continue
436+
pp_group_endpoints = [
437+
self.pp_group_endpoints[pair[0]],
438+
self.pp_group_endpoints[pair[1]],
439+
]
440+
if pair[0] < pair[1]:
441+
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
442+
else:
443+
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
444+
pp_rank = 0 if self.pp_rank == pair[0] else 1
443445
self._collective_helper._init_communicator(
444446
self._startup_program,
445447
self.current_endpoint,
446-
self.pp_group_endpoints,
447-
self.pp_rank,
448-
self.pp_ring_id + 2,
448+
pp_group_endpoints,
449+
pp_rank,
450+
ring_id,
449451
False,
450452
global_ring_id=self.global_ring_id,
451453
sync=False)
452454
# append_naive_sync(startup_block, self.startup_prog_sync_var,
453455
# self.global_ring_id)
454-
else:
455-
assert self.schedule_mode == '1F1B'
456-
for pair in self.pipeline_pair:
457-
pair_key = pair[0] * 1000 + pair[1]
458-
ring_id = self.pp_ring_map[pair_key]
459-
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
460-
if self.pp_rank not in pair: continue
461-
pp_group_endpoints = [
462-
self.pp_group_endpoints[pair[0]],
463-
self.pp_group_endpoints[pair[1]],
464-
]
465-
if pair[0] < pair[1]:
466-
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
467-
else:
468-
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[
469-
1] - 1
470-
pp_rank = 0 if self.pp_rank == pair[0] else 1
471-
self._collective_helper._init_communicator(
472-
self._startup_program,
473-
self.current_endpoint,
474-
pp_group_endpoints,
475-
pp_rank,
476-
ring_id,
477-
False,
478-
global_ring_id=self.global_ring_id,
479-
sync=False)
480-
# append_naive_sync(startup_block, self.startup_prog_sync_var,
481-
# self.global_ring_id)
482-
483-
# TODO (JZ-LIANG) to unify this shit
456+
457+
# TODO (JZ-LIANG) to unify this shit
484458
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
485459
self.pp_rank_, self.pp_rank)
486460

0 commit comments

Comments
 (0)