@@ -97,6 +97,31 @@ def setup(self, group: "ProcessGroup|None"):
     def check_ranks_in_range(self, start, stop):
         check_ranks_in_range(self.global_ranks, start, stop)

+    @classmethod
+    def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: tuple[int, int]) -> typing.Self:
+        start = global_rank
+        rank = 0
+        world_size = 1
+        for i, (size, stride) in enumerate(sizes_and_strides):
+            if i > 0:
+                Assert.multiple(stride, sizes_and_strides[i - 1][1])
+            rank_ = global_rank // stride % size
+            start -= rank_ * stride
+            rank += world_size * rank_
+            world_size *= size
+        global_ranks = [start]
+        for size, stride in sizes_and_strides:
+            if size == 1:
+                continue
+            if len(global_ranks) == 1:
+                global_ranks = range(start, start + size * stride, stride)
+            elif isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start:
+                global_ranks = range(start, start + size * stride, global_ranks.step)
+            else:
+                global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks]
+        Assert.eq(len(global_ranks), world_size)
+        return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks)
+

 def check_ranks_in_range(global_ranks, start, stop):
     Assert.geq(min(global_ranks), start)
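To make the new classmethod concrete: the first loop decomposes `global_rank` into per-dim coordinates (flattened with earlier dims varying fastest) and recovers the group's first rank, and the second loop rebuilds the full member list. Below is a dependency-free sketch of the same math; the function name and example values are illustrative only, and the real method additionally validates that strides nest and keeps `global_ranks` as a `range` where possible.

# Standalone sketch, not part of the change.
def sizes_and_strides_sketch(global_rank: int, *dims: tuple[int, int]):
    start, rank, world_size = global_rank, 0, 1
    for size, stride in dims:
        coord = global_rank // stride % size  # this rank's coordinate along the dim
        start -= coord * stride  # peel off the offset to find the group's first rank
        rank += world_size * coord  # flatten coordinates, earlier dims vary fastest
        world_size *= size
    # Rebuild the member list by shifting the prefix group along each dim.
    global_ranks = [start]
    for size, stride in dims:
        global_ranks = [r + k * stride for k in range(size) for r in global_ranks]
    return rank, world_size, global_ranks

# A 2x2 group seen from global rank 5 (inner stride 1, outer stride 4):
print(sizes_and_strides_sketch(5, (2, 1), (2, 4)))  # -> (3, 4, [0, 1, 4, 5])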
@@ -112,6 +137,7 @@ class DistributedDimNames:
     sequence_data = "sequence_data"
     batch_data = "batch_data"
     tensor_and_sequence_data = "tensor_and_sequence_data"
+    model_and_sequence_data = "model_and_sequence_data"
     tensor_and_data = "tensor_and_data"

@@ -300,88 +326,68 @@ def _validate(self) -> None:
         else:
             self.distributed_dims = {}

-        data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1)
+        tensor_stride = 1
+        sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1)
+        batch_data_stride = sequence_data_stride * self.sequence_data_parallel
         pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel)

-        self._add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.world,
-                size=self.world_size,
-                rank=self.rank,
-                global_ranks=range(self.world_size),
-            )
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.world,
+            (self.world_size, 1),
+        )
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.data,
+            (self.sequence_data_parallel, sequence_data_stride),
+            (self.batch_data_parallel, batch_data_stride),
+        )
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride)
         )
-        self._add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.data,
-                size=self.data_parallel,
-                rank=self.data_rank,
-                global_ranks=self._get_global_ranks(self.data_parallel, data_stride),
-            )
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride)
         )
-        self._add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.pipeline,
-                size=self.pipeline_parallel,
-                rank=self.pipeline_rank,
-                global_ranks=self._get_global_ranks(self.pipeline_parallel, pipeline_stride),
-            )
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.sequence_data,
+            (self.sequence_data_parallel, sequence_data_stride),
         )
-        self._add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.tensor,
-                size=self.tensor_parallel,
-                rank=self.tensor_rank,
-                global_ranks=self._get_global_ranks(self.tensor_parallel, 1),
-            )
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride)
         )
-        self._add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.sequence_data,
-                size=self.sequence_data_parallel,
-                rank=self.sequence_data_rank,
-                global_ranks=self._get_global_ranks(self.sequence_data_parallel, data_stride),
-            )
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.tensor_and_sequence_data,
+            (self.tensor_parallel, tensor_stride),
+            (self.sequence_data_parallel, sequence_data_stride),
         )
-        self._add_distributed_dim(
-            DistributedDim(
-                name=DistributedDimNames.batch_data,
-                size=self.batch_data_parallel,
-                rank=self.batch_data_rank,
-                global_ranks=self._get_global_ranks(
-                    self.batch_data_parallel, data_stride * self.sequence_data_parallel
-                ),
-            )
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.tensor_and_data,
+            (self.tensor_parallel, tensor_stride),
+            (self.sequence_data_parallel, sequence_data_stride),
+            (self.batch_data_parallel, batch_data_stride),
         )
-        # Global ranks wrong with pipeline first, so we hide the dims as a safety check.
-        if not self.pipeline_first:
-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.tensor_and_sequence_data,
-                    size=self.sequence_data_parallel * self.tensor_parallel,
-                    rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
-                    global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1),
-                )
-            )
-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.tensor_and_data,
-                    size=self.data_parallel * self.tensor_parallel,
-                    rank=self.tensor_rank + self.data_rank * self.tensor_parallel,
-                    global_ranks=self._get_global_ranks(self.data_parallel * self.tensor_parallel, 1),
-                )
-            )

-        super()._validate()
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.model_and_sequence_data,
+            (self.tensor_parallel, tensor_stride),
+            (
+                (self.pipeline_parallel, pipeline_stride)
+                if self.pipeline_first
+                else (self.sequence_data_parallel, sequence_data_stride)
+            ),
+            (
+                (self.sequence_data_parallel, sequence_data_stride)
+                if self.pipeline_first
+                else (self.pipeline_parallel, pipeline_stride)
+            ),
+        )

+        super()._validate()
         if self.reference_config is not None:
             self.compare(self.reference_config, ValueError)
         Assert.in_range(self.rank, 0, self.world_size)
         Assert.in_range(self.local_rank, 0, self.local_world_size)

-    def _get_global_ranks(self, size: int, stride: int) -> range:
-        start = self.rank // (size * stride) * size * stride + self.rank % stride
-        return range(start, start + size * stride, stride)
+    def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_strides: tuple[int, int]) -> None:
+        self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides))

     def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
         Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim)
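As a sanity check on the new stride layout, here is a small standalone computation with hypothetical parallelism sizes (all names are local to this sketch, not part of the change). It also shows why `model_and_sequence_data` swaps its last two dims on `pipeline_first`: `from_sizes_and_strides` asserts that strides nest in increasing order.

# Hypothetical sizes: tensor=2, sequence-data=2, batch-data=2, pipeline=2,
# so world_size = 16 and data_parallel = 4.
tensor, sequence_data, batch_data, pipeline = 2, 2, 2, 2
data = sequence_data * batch_data
for pipeline_first in (False, True):
    tensor_stride = 1
    sequence_data_stride = tensor * (pipeline if pipeline_first else 1)
    batch_data_stride = sequence_data_stride * sequence_data
    pipeline_stride = tensor * (1 if pipeline_first else data)
    print(pipeline_first, (tensor_stride, sequence_data_stride, batch_data_stride, pipeline_stride))
# False (1, 2, 4, 8): pipeline is the outermost dim.
# True  (1, 4, 8, 2): pipeline sits right after tensor, so the
# model_and_sequence_data dims must be passed as tensor, pipeline, sequence-data.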