@@ -265,14 +265,20 @@ def forward_kernel(
265265 qk += tl .sum (block_qk , 1 ) / 16.
266266 qk += tl .where (block_masks [:, None ], 0 , float ("-inf" ))
267267
268+ # attention
269+
268270 m_ij = tl .maximum (tl .max (qk , 1 ) * softmax_scale , lse_i )
269271 p = tl .exp (qk * softmax_scale - m_ij [:, None ])
270272
271273 l_ij = tl .sum (p , 1 )
272274
275+ # renormalize the running output
276+
273277 acc_o_scale = tl .exp (m_i - m_ij )
274278 acc_o = acc_o * acc_o_scale [:, None ]
275279
280+ # aggregate values
281+
276282 v_block = tl .load (block_v_ptrs )
277283 v_block = tl .reshape (v_block , (BLOCK , BLOCK , BLOCK_HEADDIM ))
278284
@@ -748,6 +754,17 @@ def backward_kernel_one_col_block(
748754 block_k = tl .load (block_k_ptrs )
749755 block_v = tl .load (block_v_ptrs )
750756
757+ q_expanded = tl .expand_dims (q , 1 )
758+ q_expanded = tl .broadcast_to (q_expanded , (BLOCK , 16 , BLOCK_HEADDIM ))
759+
760+ block_k = tl .permute (block_k , (0 , 2 , 1 ))
761+ block_qk = tl .dot (q_expanded , block_k )
762+
763+ qk = tl .sum (block_qk , 1 ) / 16.
764+ qk += tl .where (block_masks [:, None ], 0 , float ("-inf" ))
765+
766+ p = tl .exp (qk * softmax_scale - lse_i [:, None ])
767+
751768 # # increment pointers
752769 # dq_ptrs += BLOCK * stride_dqm
753770 # q_ptrs += BLOCK * stride_qm
0 commit comments