@@ -856,6 +856,19 @@ def _wait_tensor(tensor):
     return tensor


+def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
+    shape = x.shape
+    # HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization
+    # to benchmark triton codegen fails somewhere:
+    # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
+    # ValueError: Tensors must be contiguous
+    x = x.flatten()
+    x = funcol.all_to_all_single(x, None, None, group)
+    x = x.reshape(shape)
+    x = _wait_tensor(x)
+    return x
+
+
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -1003,28 +1016,26 @@ def forward(
         ctx.backward_op = backward_op
         ctx.op_ctx = torch.autograd.function.FunctionCtx()

-        B, S_LOCAL, H, D = query.shape
+        B, S_Q_LOCAL, H, D = query.shape
+        _, S_KV_LOCAL, _, _ = key.shape
         H_LOCAL = H // world_size
-        query, key, value = (
-            x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
-            for x in (query, key, value)
-        )
-        query, key, value = (
-            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group)) for x in (query, key, value)
-        )
+        query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

         out = forward_op(ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
         if return_lse:
             out, lse, *_ = out

-        out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
-        out = _wait_tensor(funcol.all_to_all_single(out, None, None, group=group))
+        out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+        out = _all_to_all_single(out, group)
         out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()

         if return_lse:
-            lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
-            lse = _wait_tensor(funcol.all_to_all_single(lse, None, None, group=group))
+            lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
+            lse = _all_to_all_single(lse, group)
             lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
         else:
             lse = None
@@ -1046,7 +1057,7 @@ def backward(
         H_LOCAL = H // world_size

         grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
-        grad_out = _wait_tensor(funcol.all_to_all_single(grad_out, None, None, group=group))
+        grad_out = _all_to_all_single(grad_out, group)
         grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()

         grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
@@ -1055,10 +1066,7 @@ def backward(
             x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
             for x in (grad_query_op, grad_key_op, grad_value_op)
         )
-        grad_query, grad_key, grad_value = (
-            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group))
-            for x in (grad_query, grad_key, grad_value)
-        )
+        grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value))
        grad_query, grad_key, grad_value = (
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )
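
Not part of the commit: a minimal sketch of how the new flatten → `all_to_all_single` → reshape → wait pattern could be exercised in isolation. Only `funcol.all_to_all_single` and `funcol.wait_tensor` are existing PyTorch APIs here; the helper name `all_to_all_contiguous_safe`, the `gloo` backend, and the tensor sizes are illustrative assumptions, and the script assumes it is launched under `torchrun`.

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def all_to_all_contiguous_safe(x: torch.Tensor, group) -> torch.Tensor:
    # Illustrative stand-in for _all_to_all_single: flatten so the collective always
    # receives a contiguous 1-D buffer, then restore the shape and wait on the result.
    shape = x.shape
    out = funcol.all_to_all_single(x.flatten(), None, None, group)
    out = funcol.wait_tensor(out)
    return out.reshape(shape)


if __name__ == "__main__":
    # Assumption: launched with e.g. `torchrun --nproc_per_node=2 this_script.py` on CPU/gloo.
    dist.init_process_group("gloo")
    world_size = dist.get_world_size()
    group = dist.group.WORLD
    # Leading dim equals world_size, matching the permute(2, 1, 0, 3, 4) layout in forward(),
    # so each rank exchanges exactly one contiguous block with every peer.
    x = torch.randn(world_size, 4, 1, 2, 8)
    y = all_to_all_contiguous_safe(x, group)
    assert y.shape == x.shape
    dist.destroy_process_group()
```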