@@ -1040,7 +1040,41 @@ def backward(
         grad_out: torch.Tensor,
         *args,
     ):
-        raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")
+        parallel_config = _AttentionBackendRegistry._parallel_config
+        ulysses_mesh = parallel_config._ulysses_mesh
+        world_size = parallel_config.ulysses_degree
+        group = ulysses_mesh.get_group()
+
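+        # Ulysses backward mirrors the two all-to-alls of the forward pass:
+        # grad_out arrives sequence-sharded with all heads, shape (B, S_LOCAL, H, D).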
+        B, S_LOCAL, H, D = grad_out.shape
+        H_LOCAL = H // world_size
+
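+        # All-to-all #1: re-shard grad_out into the layout the attention op ran in,
+        # i.e. full sequence with local heads, (B, S, H_LOCAL, D).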
+        grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        grad_out = _wait_tensor(funcol.all_to_all_single(grad_out, None, None, group=group))
+        grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+
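+        # Run the wrapped op's backward in the full-sequence, local-head layout.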
+        grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
+
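+        # All-to-all #2: re-shard the q/k/v grads back to the inputs' layout,
+        # sequence-sharded with all heads, (B, S_LOCAL, H, D).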
+        grad_query, grad_key, grad_value = (
+            x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+            for x in (grad_query_op, grad_key_op, grad_value_op)
+        )
+        grad_query, grad_key, grad_value = (
+            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group))
+            for x in (grad_query, grad_key, grad_value)
+        )
+        grad_query, grad_key, grad_value = (
+            x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
+        )
+
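+        # None grads for the remaining non-tensor forward inputs.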
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
 
 def _templated_context_parallel_attention(