@@ -1022,11 +1022,54 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     return x
 
 def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
-    pass
+    group_world_size = dist.get_world_size(group)
+
+    if scatter_idx == 2 and gather_idx == 1:
+        # (B, S/P, H, D) -> (B, S, H/P, D): gather sequence, scatter heads.
+        B, S_LOCAL, H, D = x.shape
+        S = S_LOCAL * group_world_size
+        H_LOCAL = H // group_world_size
+
+        # (B, S/P, H, D) -> (B, S/P, P, H/P, D) -> (P, S/P, B, H/P, D);
+        # all_to_all_single exchanges chunks along dim 0, so the group
+        # dimension must come first.
+        x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
+                  .permute(2, 1, 0, 3, 4)
+                  .contiguous())
+
+        out = torch.empty_like(x_temp)
+        if group_world_size > 1:
+            dist.all_to_all_single(out, x_temp, group=group)
+        else:
+            out = x_temp
+        # (P, S/P, B, H/P, D) -> (S, B, H/P, D) -> (B, S, H/P, D)
+        out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
+        return out
+    elif scatter_idx == 1 and gather_idx == 2:
+        # (B, S, H/P, D) -> (B, S/P, H, D): gather heads, scatter sequence.
+        B, S, H_LOCAL, D = x.shape
+        H = H_LOCAL * group_world_size
+        S_LOCAL = S // group_world_size
+
+        # (B, S, H/P, D) -> (B, P, S/P, H/P, D) -> (P, H/P, S/P, B, D)
+        x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
+                  .permute(1, 3, 2, 0, 4)
+                  .contiguous())
+
+        output = torch.empty_like(x_temp)
+        if group_world_size > 1:
+            dist.all_to_all_single(output, x_temp, group=group)
+        else:
+            output = x_temp
+        # (P, H/P, S/P, B, D) -> (H, S/P, B, D) -> (B, S/P, H, D)
+        output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
+        return output
+    else:
+        raise ValueError("Invalid scatter/gather indices for all_to_all_double.")
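+
+# Note: the two branches are inverses of each other; with H and S both
+# divisible by the group size, the (2, 1) transform followed by (1, 2)
+# reproduces the original (B, S/P, H, D) layout.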
 
 
 class SeqAllToAllDouble(torch.autograd.Function):
-    pass
+    @staticmethod
+    def forward(ctx, group, x, scatter_idx, gather_idx):
+        ctx.group = group
+        ctx.scatter_idx = scatter_idx
+        ctx.gather_idx = gather_idx
+        return _all_to_all_double(x, scatter_idx, gather_idx, group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # The gradient of an all-to-all is the inverse all-to-all: swap
+        # scatter and gather; None for the non-tensor forward arguments.
+        grad_x = _all_to_all_double(grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.group)
+        return None, grad_x, None, None
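+
+# Usage sketch (assumed call pattern; the attention wiring below is not part
+# of this commit):
+#   q = SeqAllToAllDouble.apply(group, q, 2, 1)  # (B, S/P, H, D) -> (B, S, H/P, D)
+#   ...attention over the full sequence with H/P local heads...
+#   o = SeqAllToAllDouble.apply(group, o, 1, 2)  # (B, S, H/P, D) -> (B, S/P, H, D)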
 
 
 