@@ -428,59 +428,33 @@ def _init_comm(self):
428
428
429
429
# pp ring
430
430
if self .pp_degree > 1 :
431
- if self .schedule_mode == 'F-then-B' : # GPipe
432
- self ._collective_helper ._init_communicator (
433
- self ._startup_program ,
434
- self .current_endpoint ,
435
- self .pp_group_endpoints ,
436
- self .pp_rank ,
437
- self .pp_ring_id ,
438
- False ,
439
- global_ring_id = self .global_ring_id ,
440
- sync = False )
441
- # append_naive_sync(startup_block, self.startup_prog_sync_var,
442
- # self.global_ring_id)
431
+ for pair in self .pipeline_pair :
432
+ pair_key = pair [0 ] * 1000 + pair [1 ]
433
+ ring_id = self .pp_ring_map [pair_key ]
434
+ print ("pp pair:{}, ring_id: {}" .format (pair , ring_id ))
435
+ if self .pp_rank not in pair : continue
436
+ pp_group_endpoints = [
437
+ self .pp_group_endpoints [pair [0 ]],
438
+ self .pp_group_endpoints [pair [1 ]],
439
+ ]
440
+ if pair [0 ] < pair [1 ]:
441
+ start_ring_id = self .pp_ring_id + pair [1 ] - pair [0 ] - 1
442
+ else :
443
+ start_ring_id = self .pp_ring_id + 2 + pair [0 ] - pair [1 ] - 1
444
+ pp_rank = 0 if self .pp_rank == pair [0 ] else 1
443
445
self ._collective_helper ._init_communicator (
444
446
self ._startup_program ,
445
447
self .current_endpoint ,
446
- self . pp_group_endpoints ,
447
- self . pp_rank ,
448
- self . pp_ring_id + 2 ,
448
+ pp_group_endpoints ,
449
+ pp_rank ,
450
+ ring_id ,
449
451
False ,
450
452
global_ring_id = self .global_ring_id ,
451
453
sync = False )
452
454
# append_naive_sync(startup_block, self.startup_prog_sync_var,
453
455
# self.global_ring_id)
454
- else :
455
- assert self .schedule_mode == '1F1B'
456
- for pair in self .pipeline_pair :
457
- pair_key = pair [0 ] * 1000 + pair [1 ]
458
- ring_id = self .pp_ring_map [pair_key ]
459
- print ("pp pair:{}, ring_id: {}" .format (pair , ring_id ))
460
- if self .pp_rank not in pair : continue
461
- pp_group_endpoints = [
462
- self .pp_group_endpoints [pair [0 ]],
463
- self .pp_group_endpoints [pair [1 ]],
464
- ]
465
- if pair [0 ] < pair [1 ]:
466
- start_ring_id = self .pp_ring_id + pair [1 ] - pair [0 ] - 1
467
- else :
468
- start_ring_id = self .pp_ring_id + 2 + pair [0 ] - pair [
469
- 1 ] - 1
470
- pp_rank = 0 if self .pp_rank == pair [0 ] else 1
471
- self ._collective_helper ._init_communicator (
472
- self ._startup_program ,
473
- self .current_endpoint ,
474
- pp_group_endpoints ,
475
- pp_rank ,
476
- ring_id ,
477
- False ,
478
- global_ring_id = self .global_ring_id ,
479
- sync = False )
480
- # append_naive_sync(startup_block, self.startup_prog_sync_var,
481
- # self.global_ring_id)
482
-
483
- # TODO (JZ-LIANG) to unify this shit
456
+
457
+ # TODO (JZ-LIANG): unify the pp rank kept for the pp optimizer (pp_rank_) with the one kept for the sharding optimizer (pp_rank)
484
458
assert self .pp_rank_ == self .pp_rank , "pp rank for pp opt [{}], pp rank for sharding opt [{}]" .format (
485
459
self .pp_rank_ , self .pp_rank )
486
460
0 commit comments