@@ -29,7 +29,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
29
29
offsetk_y = offset_y + lo
30
30
offsetv_y = offset_y + lo
31
31
# loop over k, v and update accumulator
32
- for start_n in range (lo , hi , BLOCK_N ):
32
+ for start_n in tl . range (lo , hi , BLOCK_N ):
33
33
start_n = tl .multiple_of (start_n , BLOCK_N )
34
34
# -- compute qk ----
35
35
k = desc_k .load ([0 , offsetk_y ])
@@ -44,16 +44,17 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
44
44
m_ij = tl .maximum (m_i , tl .max (qk , 1 ) * qk_scale )
45
45
qk = qk * qk_scale - m_ij [:, None ]
46
46
p = tl .math .exp2 (qk )
47
- l_ij = tl .sum (p , 1 )
48
47
# -- compute correction factor
49
48
alpha = tl .math .exp2 (m_i - m_ij )
49
+ l_ij = tl .sum (p , 1 )
50
50
l_i = l_i * alpha + l_ij
51
51
# -- update output accumulator --
52
52
acc = acc * alpha [:, None ]
53
53
# prepare p and v for the dot
54
54
v = desc_v .load ([offsetv_y , 0 ])
55
+ p = p .to (dtype )
55
56
# note that this non-transposed v for FP8 is only supported on Blackwell
56
- acc + = tl .dot (p . to ( tl . float16 ) , v )
57
+ acc = tl .dot (p , v , acc )
57
58
# update m_i and l_i
58
59
# place this at the end of the loop to reduce register pressure
59
60
m_i = m_ij
0 commit comments