@@ -1021,6 +1021,14 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     x = _wait_tensor(x)
     return x
 
+def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+    pass
+
+
+class SeqAllToAllDouble(torch.autograd.Function):
+    pass
+
+
 
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
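
The two symbols added above are committed as stubs. For orientation, here is a minimal sketch of what they would plausibly contain, modeled on the DeepSpeed-Ulysses sequence-parallel all-to-all; the [batch, seq, heads, head_dim] layout, the divisibility of the scattered dimension by the group size, and the helper body are assumptions, not this commit's implementation:

import torch
import torch.distributed as dist

def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    # Assumed semantics: split x into one chunk per rank along scatter_idx,
    # exchange chunks with every rank in the group, and stitch the received
    # chunks back together along gather_idx. With x as [B, S/P, H, D] and the
    # defaults (scatter heads, gather sequence), this yields [B, S, H/P, D].
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return x
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_idx)]
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_idx)

class SeqAllToAllDouble(torch.autograd.Function):
    # Autograd wrapper matching the call sites below:
    # SeqAllToAllDouble.apply(group, tensor, scatter_idx, gather_idx).
    @staticmethod
    def forward(ctx, group, x, scatter_idx: int, gather_idx: int):
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        return _all_to_all_double(x, scatter_idx, gather_idx, group)

    @staticmethod
    def backward(ctx, grad_output):
        # The adjoint of an all-to-all is the inverse all-to-all:
        # swap the scatter and gather dimensions on the way back.
        grad_x = _all_to_all_double(grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.group)
        return None, grad_x, None, None
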
@@ -1240,6 +1248,57 @@ def backward(
 
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
+class TemplatedUnifiedAttention(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
+        dropout_p: float,
+        is_causal: bool,
+        scale: Optional[float],
+        enable_gqa: bool,
+        return_lse: bool,
+        forward_op,
+        backward_op,
+        _parallel_config: Optional["ParallelConfig"] = None,
+    ):
+        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+        ulysses_group = ulysses_mesh.get_group()
+        ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+        ring_group = ring_mesh.get_group()
+        scatter_idx = 2
+        gather_idx = 1
+
+        query = SeqAllToAllDouble.apply(ulysses_group, query, scatter_idx, gather_idx)
+        key = SeqAllToAllDouble.apply(ulysses_group, key, scatter_idx, gather_idx)
+        value = SeqAllToAllDouble.apply(ulysses_group, value, scatter_idx, gather_idx)
+        out = TemplatedRingAttention.apply(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+        if return_lse:
+            context_layer, lse, *_ = out
+        else:
+            context_layer = out
+        output = SeqAllToAllDouble.apply(
+            ulysses_group,
+            context_layer,
+            gather_idx,
+            scatter_idx,
+        )
+        return (output, lse) if return_lse else output
 
 def _templated_context_parallel_attention(
     query: torch.Tensor,
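
For context, a hedged sketch of how the unified path would be driven once complete; `forward_op`, `backward_op`, and the `ParallelConfig` wiring are taken on faith from the surrounding code, and the tensor layout is an assumption:

# Hypothetical call, not from this commit. Inputs arrive sharded over both
# the Ulysses axis (heads redistributed via SeqAllToAllDouble) and the ring
# axis (sequence blocks rotated through TemplatedRingAttention).
out = TemplatedUnifiedAttention.apply(
    query,            # [batch, local_seq, heads, head_dim]
    key,
    value,
    None,             # attn_mask
    0.0,              # dropout_p
    False,            # is_causal
    None,             # scale (None -> default 1/sqrt(head_dim))
    False,            # enable_gqa
    False,            # return_lse
    forward_op,       # per-block attention forward (e.g. an SDPA wrapper)
    backward_op,      # matching backward
    parallel_config,  # must expose context_parallel_config._ulysses_mesh / _ring_mesh
)
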