@@ -292,18 +292,16 @@ def has_cp(self):
292292 return self .cp_size > 1
293293
294294 def prev_cp_rank (self ):
295- p = self .rank - self .tp_size
296- if p // (self .tp_size * self .cp_size ) < self .rank // (self .tp_size *
297- self .cp_size ):
298- return p + self .tp_size * self .cp_size
299- return p
295+ # cp ranks are consecutive, so prev is rank - 1 with wraparound within cp group
296+ if self .cp_rank == 0 :
297+ return self .rank + self .cp_size - 1
298+ return self .rank - 1
300299
301300 def next_cp_rank (self ):
302- p = self .rank + self .tp_size
303- if p // (self .tp_size * self .cp_size ) > self .rank // (self .tp_size *
304- self .cp_size ):
305- return p - self .tp_size * self .cp_size
306- return p
301+ # cp ranks are consecutive, so next is rank + 1 with wraparound within cp group
302+ if self .cp_rank == self .cp_size - 1 :
303+ return self .rank - self .cp_size + 1
304+ return self .rank + 1
307305
308306 def has_moe_cluster (self ):
309307 return self .moe_cluster_size > 1
@@ -378,17 +376,17 @@ class Mapping(MappingBase):
378376
379377 A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1
380378
381- 2 tp groups:
379+ 4 cp groups:
382380
383- - [0, 1, 2, 3]
384- - [4, 5, 6, 7]
381+ - [0, 1]
382+ - [2, 3]
383+ - [4, 5]
384+ - [6, 7]
385385
386- 4 cp groups:
386+ 2 tp groups:
387387
388- - [0, 4]
389- - [1, 5]
390- - [2, 6]
391- - [3, 7]
388+ - [0, 2, 4, 6]
389+ - [1, 3, 5, 7]
392390
393391 A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4
394392
@@ -437,23 +435,23 @@ class Mapping(MappingBase):
437435
438436 2 nodes with 8 GPUs, tp_size 2, pp_size 2, cp_size 2
439437
440- 4 tp groups:
438+ 4 cp groups:
441439 - [0, 1]
442440 - [2, 3]
443441 - [4, 5]
444442 - [6, 7]
445443
444+ 4 tp groups:
445+ - [0, 2]
446+ - [1, 3]
447+ - [4, 6]
448+ - [5, 7]
449+
446450 4 pp groups:
447451 - [0, 4]
448452 - [1, 5]
449453 - [2, 6]
450454 - [3, 7]
451-
452- 4 cp groups:
453- - [0, 2]
454- - [1, 3]
455- - [4, 6]
456- - [5, 7]
457455 """
458456
459457 def __new__ (cls , * args , ** kwargs ):
@@ -551,23 +549,23 @@ def __init__(self, *args, **kwargs):
551549
552550 @property
553551 def tp_rank (self ) -> int :
554- return self .rank % self .tp_size
552+ return self .rank % ( self .tp_size * self . cp_size ) // self . cp_size
555553
556554 @property
557555 def pp_rank (self ) -> int :
558556 return self .rank // (self .tp_size * self .cp_size )
559557
560558 @property
561559 def cp_rank (self ) -> int :
562- return self .rank % ( self .tp_size * self . cp_size ) // self . tp_size
560+ return self .rank % self .cp_size
563561
564562 @property
565563 def tp_group (self ) -> List [int ]:
566564 return self .tp_groups [self .pp_rank * self .cp_size + self .cp_rank ]
567565
568566 @property
569567 def pp_group (self ) -> List [int ]:
570- return self .pp_groups [self .cp_rank * self .tp_size + self .tp_rank ]
568+ return self .pp_groups [self .tp_rank * self .cp_size + self .cp_rank ]
571569
572570 @property
573571 def cp_group (self ) -> List [int ]:
@@ -598,20 +596,20 @@ def _init_parallel_groups(self):
598596 ranks = range (i , self .world_size , self .tp_size * self .cp_size )
599597 self .pp_groups .append (list (ranks ))
600598
601- # init cp group
599+ # init cp group (consecutive ranks within each tp slice)
602600 for i in range (self .pp_size ):
603601 for j in range (self .tp_size ):
604- ranks = range (i * self . tp_size * self . cp_size + j ,
605- ( i + 1 ) * self .tp_size * self .cp_size + j ,
606- self .tp_size )
602+ ranks = range (
603+ i * self .tp_size * self .cp_size + j * self . cp_size ,
604+ i * self . tp_size * self . cp_size + ( j + 1 ) * self .cp_size )
607605 self .cp_groups .append (list (ranks ))
608606
609- # init tp group
607+ # init tp group (interleaved ranks with stride of cp_size)
610608 for i in range (self .pp_size ):
611609 for j in range (self .cp_size ):
612- ranks = range (
613- i * self .tp_size * self .cp_size + j * self . tp_size ,
614- i * self . tp_size * self . cp_size + ( j + 1 ) * self .tp_size )
610+ ranks = range (i * self . tp_size * self . cp_size + j ,
611+ ( i + 1 ) * self .tp_size * self .cp_size + j ,
612+ self .cp_size )
615613 self .tp_groups .append (list (ranks ))
616614
617615 # init moe tp group
0 commit comments