@@ -1040,6 +1040,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
             out = _wait_tensor(out)
         else:
             out = x_temp
+        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
         out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
         out = out.reshape(B, S, H_LOCAL, D)
         return out
@@ -1053,6 +1054,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
                  .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))
 
         if group_world_size > 1:
+            # maybe we need to use the _all_to_all_single helper here to avoid contiguity issues
            output = funcol.all_to_all_single(x_temp, None, None, group)
            output = _wait_tensor(output)
        else:
@@ -1066,12 +1068,15 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
 
 class SeqAllToAllDim(torch.autograd.Function):
     @staticmethod
-    def forward():
-        pass
+    def forward(ctx, group, input, scatter_id=2, gather_id=1):
+        ctx.group = group
+        ctx.scatter_id = scatter_id
+        ctx.gather_id = gather_id
+        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
 
     @staticmethod
-    def backward():
-        pass
+    def backward(ctx, *grad_outputs):
+        return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None)
 
 
 
@@ -1313,12 +1318,13 @@ def forward(ctx: torch.autograd.function.FunctionCtx,
         ulysses_group = ulysses_mesh.get_group()
         ring_mesh = _parallel_config.context_parallel_config._ring_mesh
         ring_group = ring_mesh.get_group()
+        # hardcoded for now
         scatter_idx = 2
         gather_idx = 1
 
-        query = SeqAllToAllDouble.apply(ulysses_group, query, scatter_idx, gather_idx)
-        key = SeqAllToAllDouble.apply(ulysses_group, key, scatter_idx, gather_idx)
-        value = SeqAllToAllDouble.apply(ulysses_group, value, scatter_idx, gather_idx)
+        query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+        key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+        value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
         out = TemplatedRingAttention.apply(
             query,
             key,
@@ -1337,12 +1343,17 @@ def forward(ctx: torch.autograd.function.FunctionCtx,
             context_layer, lse, *_ = out
         else:
             context_layer = out
-        output = SeqAllToAllDouble.apply(
+        output = SeqAllToAllDim.apply(
             ulysses_group,
             context_layer,
             gather_idx,
             scatter_idx,
         )
+        if return_lse:
+            # not sure if this is correct
+            lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
+            return (output, lse)
+        return output
 
 def _templated_context_parallel_attention(
     query: torch.Tensor,
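
For reference, here is a self-contained, single-process sketch of the layout change the new `SeqAllToAllDim` path performs: scatter over the head dimension (dim 2), gather over the sequence dimension (dim 1). It only mimics the data movement with plain tensor ops; the sizes and the emulated rank `q` are made-up values, and the real code performs the exchange with `funcol.all_to_all_single` across the Ulysses group.

```python
import torch

# Hypothetical sizes for illustration only.
world_size = 4                     # assumed Ulysses group size
B, S_LOCAL, H, D = 2, 8, 16, 64    # per-rank input layout: (B, S_LOCAL, H, D)
S, H_LOCAL = S_LOCAL * world_size, H // world_size

# Pretend we can see every rank's local shard (all heads, a slice of the sequence).
shards = [torch.randn(B, S_LOCAL, H, D) for _ in range(world_size)]

# Emulate what rank q holds after the exchange: head-chunk q of every rank's
# sequence chunk, concatenated along the sequence dimension.
q = 0
received = [s.reshape(B, S_LOCAL, world_size, H_LOCAL, D)[:, :, q] for s in shards]
out = torch.cat(received, dim=1)

print(out.shape)  # torch.Size([2, 32, 4, 64]) == (B, S, H_LOCAL, D)
```

Running the same exchange with the head and sequence roles swapped restores the original layout, which is consistent with `SeqAllToAllDim.backward` calling `_all_to_all_dim_exchange` with `gather_id` and `scatter_id` reversed, and with the attention output (and, tentatively, the LSE) being sent back through the op with `gather_idx` and `scatter_idx` swapped.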