@@ -1021,22 +1021,23 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     x = _wait_tensor(x)
     return x

-def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
     group_world_size = funcol.get_world_size(group)
-    # dist.get_world_size(group)

     if scatter_idx == 2 and gather_idx == 1:
         B, S_LOCAL, H, D = x.shape
         S = S_LOCAL * group_world_size
         H_LOCAL = H // group_world_size

+        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
         x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
-                  .permute(0, 2, 1, 3, 4).contiguous()
+                  .transpose(0, 2).contiguous()
                   )

-        out = torch.empty_like(x_temp)
         if group_world_size > 1:
-            funcol.all_to_all_single(out, x_temp, None, None, group)
+            # maybe use the _all_to_all_single helper here to avoid contiguity issues
+            out = funcol.all_to_all_single(x_temp, None, None, group=group)
+            out = _wait_tensor(out)
         else:
             out = x_temp
         out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
@@ -1050,19 +1051,20 @@ def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int =
10501051 #
10511052 x_temp = (x .reshape (B , group_world_size , S_LOCAL , H_LOCAL , D )
10521053 .permute (1 , 3 , 2 , 0 , 4 ).reshape (group_world_size , H_LOCAL , S_LOCAL , B , D ))
1053- output = torch . empty_like ( x_temp )
1054+
10541055 if group_world_size > 1 :
1055- funcol .all_to_all_single (output , x_temp , None , None , group )
1056+ output = funcol .all_to_all_single (x_temp , None , None , group )
1057+ output = _wait_tensor (output )
10561058 else :
10571059 output = x_temp
10581060 output = output .reshape (H , S_LOCAL , B , D ).transpose (0 , 2 ).contiguous ()
10591061 output = output .reshape (B , S_LOCAL , H , D )
10601062 return output
10611063 else :
1062- raise ValueError ("Invalid scatter/gather indices for all_to_all_double ." )
1064+ raise ValueError ("Invalid scatter/gather indices for _all_to_all_dim_exchange ." )
10631065
10641066
1065- class SeqAllToAllDouble (torch .autograd .Function ):
1067+ class SeqAllToAllDim (torch .autograd .Function ):
10661068 @staticmethod
10671069 def forward ():
10681070 pass
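`SeqAllToAllDim.forward` is still a stub in this commit. Below is a minimal sketch of how it could be filled in, assuming the `_all_to_all_dim_exchange` helper above and the usual Ulysses-style pattern where backward re-runs the exchange with the scatter and gather indices swapped; the `apply(group, x, scatter_idx, gather_idx)` signature is an assumption for illustration, not part of this commit.

```python
import torch
from typing import Tuple


class SeqAllToAllDim(torch.autograd.Function):
    """Sketch: redistribute a (B, S, H, D) activation across the sequence-parallel group.

    Forward scatters `scatter_idx` and gathers `gather_idx` (e.g. 2 and 1:
    shard heads, assemble the full sequence). Backward applies the exchange
    with the indices swapped, which is the inverse redistribution, so the
    gradient comes back in the input's original layout.
    """

    @staticmethod
    def forward(ctx, group, x: torch.Tensor,
                scatter_idx: int = 2, gather_idx: int = 1) -> torch.Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        # Assumes the _all_to_all_dim_exchange helper defined in the diff above.
        return _all_to_all_dim_exchange(x, scatter_idx, gather_idx, group)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]:
        # Undo the forward exchange: scatter what was gathered, gather what was scattered.
        grad_x = _all_to_all_dim_exchange(
            grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.group
        )
        return None, grad_x, None, None
```

Under that assumed signature, an attention layer would call `SeqAllToAllDim.apply(group, q, 2, 1)` on q/k/v to trade the sequence shard for a head shard before computing attention, and `SeqAllToAllDim.apply(group, attn_out, 1, 2)` afterwards to return to the sequence-sharded layout.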