 import torch.nn.functional as F
 from deepspeed.utils import groups
 from .mappings import drop_tokens, gather_tokens
-
 if TYPE_CHECKING:
     Base = Module[Tensor]
 else:
@@ -96,16 +95,19 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
 class _AllToAll(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, async_op=False) -> Tensor:  # type: ignore
         ctx.group = group
         input = input.contiguous()
         output = torch.empty_like(input)
-        dist.all_to_all_single(output, input, group=group)
-        return output
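+        # When async_op is True, also return the communication handle so the caller
+        # can overlap the all-to-all with other work and call wait() before using the output.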
+        work = dist.all_to_all_single(output, input, group=group, async_op=async_op)
+        if async_op:
+            return output, work
+        else:
+            return output
 
     @staticmethod
     def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
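+        # The trailing None is the gradient placeholder for the new async_op argument,
+        # which is not differentiable.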
+        return (None, _AllToAll.apply(ctx.group, *grad_output), None)
 
 
 # einsum rewrites are on par or more performant
@@ -550,6 +552,7 @@ class MOELayer(Base):
         expert (torch.nn.Module):
             expert network
     """
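+    # Side CUDA stream used for the device-to-device copies that are overlapped with communication.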
+    d2d_stream = torch.cuda.Stream()
 
     def __init__(self,
                  gate: Module,
@@ -572,6 +575,8 @@ def __init__(self,
         self.wall_clock_breakdown = False
 
         self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
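+        # Overlap the expert all-to-all with the tensor-parallel allgather by splitting
+        # the dispatched tokens into shard_num chunks along the capacity dimension.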
+        self.enable_pipeline = True
+        self.shard_num = 4
 
         if self.use_tutel:
             logger.info('Using Tutel optimizations.')
@@ -586,8 +591,54 @@ def _set_ep_group(self, ep_group):
         self.ep_group = ep_group
         self.gate._set_ep_group(ep_group)
 
-    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
+    # During multi-machine MoE training, all-to-all is the inter-machine communication and
+    # allgather is the intra-machine communication. They use different communication links,
+    # so they can be executed in parallel.
+    # The input has shape (E, C, M). We shard it along the C dim and run the all-to-all
+    # shard by shard, so the allgather of one shard overlaps with the all-to-all of the next:
+    # A  E  I  M
+    # A1 E1 I1 M1
+    # A2 E2 I2 M2
+    # A3 E3 I3 M3
+    # A4 E4 I4 M4
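+    # e.g. with shard_num=4, while shard 1 is being allgathered and copied into place,
+    # the all-to-alls of shards 2-4 are still in flight.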
+    def pipeline_alltoall_with_allgather(self, input, shard_dim=1) -> Tensor:
+        if not self.enable_pipeline:
+            input = _AllToAll.apply(self.ep_group, input)
+            input = gather_tokens(input, dim=shard_dim)
+            return input
+
+        assert self.shard_num > 0, f"shard_num must be a positive number, but got {self.shard_num}"
+        input_chunks = list(input.chunk(self.shard_num, dim=shard_dim))
+        world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
+        dims = list(input.size())
+        dims[shard_dim] = dims[shard_dim] * world_size
+        output = torch.empty(dims, dtype=input.dtype, device=input.device)
+        input_gather_dim_len = input.shape[shard_dim]
+        have_gather_len = 0
+        works = []
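+        # Launch the all-to-all for every shard asynchronously up front; the Work handles
+        # are awaited one shard at a time in the loop below.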
+        for i in range(len(input_chunks)):
+            input_chunks[i], work = _AllToAll.apply(self.ep_group, input_chunks[i], True)
+            works.append(work)
+
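+        # Consume the shards in order: wait for shard i's all-to-all, allgather it across the
+        # tensor-parallel ranks, then copy it into the preallocated output on the side stream
+        # while the all-to-alls of the remaining shards are still running.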
+        current_stream = torch.cuda.current_stream()
+        for i in range(len(input_chunks)):
+            works[i].wait()
+            # allgather and chunk along dim 0 so we can avoid an unnecessary cat in gather_tokens
+            gather_out = gather_tokens(input_chunks[i], dim=0)
+            gather_list = gather_out.chunk(world_size, dim=0)
+            dim_len = gather_list[0].shape[shard_dim]
+            MOELayer.d2d_stream.wait_stream(current_stream)
+
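+            # Rank j's gathered piece of this shard is written at offset j * input_gather_dim_len
+            # plus the length already filled by the previous shards.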
+            for j in range(len(gather_list)):
+                start = input_gather_dim_len * j + have_gather_len
+                with torch.cuda.stream(MOELayer.d2d_stream):
+                    torch.narrow(output, shard_dim, start, dim_len).copy_(gather_list[j])
+            have_gather_len += dim_len
+
+        current_stream.wait_stream(MOELayer.d2d_stream)
+        return output
 
+    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
             self.timers(MOE_TIMER).start()
 
@@ -611,9 +662,6 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
             dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
 
-        if self.wall_clock_breakdown:
-            self.timers(FIRST_ALLTOALL_TIMER).start()
-
         tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
         if tensor_model_world_size > 1:
             # If the non-expert is tensor-parallel,
@@ -628,18 +676,17 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             # an allgather to ensure correctness,
             dispatched_input = drop_tokens(dispatched_input, dim=1)
 
-        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
-
         if self.wall_clock_breakdown:
-            self.timers(FIRST_ALLTOALL_TIMER).stop()
-            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
+            self.timers(FIRST_ALLTOALL_TIMER).start()
 
         if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
-            # if both expert and non-expert are tensor-parallel
-            # the dropped duplicate tokens need to be gathered on each
-            # tensor parallel rank again to ensure correctness
-            dispatched_input = gather_tokens(dispatched_input, dim=1)
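+            # If both the expert and non-expert are tensor-parallel, the duplicate tokens dropped
+            # above must be gathered back on each tensor-parallel rank; the overlapped path fuses
+            # that allgather with the all-to-all.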
+            dispatched_input = self.pipeline_alltoall_with_allgather(dispatched_input)
+        else:
+            dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
 
+        if self.wall_clock_breakdown:
+            self.timers(FIRST_ALLTOALL_TIMER).stop()
+            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
         # Re-shape after all-to-all: ecm -> gecm
         dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
         expert_output = self.experts(dispatched_input)
@@ -654,18 +701,12 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).start()
 
-        expert_output = _AllToAll.apply(self.ep_group, expert_output)
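+        # The gather of the dropped duplicate tokens for the next layer's tensor-parallel
+        # non-expert (see the removed block below) now happens inside pipeline_alltoall_with_allgather.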
+        expert_output = self.pipeline_alltoall_with_allgather(expert_output)
 
         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).stop()
             self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)
 
-        if tensor_model_world_size > 1:
-            # the dropped duplicate tokens need to be gathered on each
-            # tensor parallel rank again for the tensor-parallel
-            # non-expert of the next layer.
-            expert_output = gather_tokens(expert_output, dim=1)
-
         if self.use_tutel:
             combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
         else: