 
 from triton_kernels_benchmark import flash_attention_benchmark
 
+# FIXME: Revert temporary source code modification done in last commit of PR #4399.
+
 
 # pylint: disable=unused-argument
 @triton.jit
 def _attn_fwd_inner(acc, l_i, m_i, q,  #
-                    K_desc, V_desc,  #
-                    start_m, qk_scale,  #
+                    desc_k, desc_v,  #
+                    offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
                     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,  #
                     STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                     N_CTX: tl.constexpr):
@@ -24,13 +26,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
     # causal = False
     else:
         lo, hi = 0, N_CTX
-    off_k = lo
-    off_v = lo
+    offsetk_y = offset_y + lo
+    offsetv_y = offset_y + lo
     # loop over k, v and update accumulator
     for start_n in range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = K_desc.load([0, off_k])
+        k = desc_k.load([0, offsetk_y])
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         if STAGE == 2:
@@ -43,18 +45,20 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
             qk = qk * qk_scale - m_ij[:, None]
         p = tl.math.exp2(qk)
         l_ij = tl.sum(p, 1)
-        # -- update m_i and l_i
+        # -- compute correction factor
         alpha = tl.math.exp2(m_i - m_ij)
         l_i = l_i * alpha + l_ij
         # -- update output accumulator --
         acc = acc * alpha[:, None]
-        # update acc
-        v = V_desc.load([off_v, 0])
+        # prepare p and v for the dot
+        v = desc_v.load([offsetv_y, 0])
+        # note that this non transposed v for FP8 is only supported on Blackwell
         acc += tl.dot(p.to(tl.float16), v)
         # update m_i and l_i
+        # place this at the end of the loop to reduce register pressure
         m_i = m_ij
-        off_v += BLOCK_N
-        off_k += BLOCK_N
+        offsetk_y += BLOCK_N
+        offsetv_y += BLOCK_N
     return acc, l_i, m_i
 
 
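The rescaling above is the standard online-softmax update: whenever the running row max m_i grows, the partial normalizer l_i and the output accumulator acc are scaled down by alpha before the new tile is folded in. A minimal NumPy sketch of that update (not part of the PR; names mirror the kernel, and qk_scale is assumed to already carry the 1.44269504 = log2(e) factor, so exp2 here equals the natural exp of the sm_scale-scaled scores):

import numpy as np

def online_softmax_tile_update(acc, l_i, m_i, qk, v, qk_scale):
    # qk: [BLOCK_M, BLOCK_N] raw scores for this tile, v: [BLOCK_N, D]
    m_ij = np.maximum(m_i, qk.max(axis=1) * qk_scale)  # new running row max
    p = np.exp2(qk * qk_scale - m_ij[:, None])         # unnormalized tile probabilities
    alpha = np.exp2(m_i - m_ij)                        # correction factor for old sums
    l_i = l_i * alpha + p.sum(axis=1)                  # rescale and extend the normalizer
    acc = acc * alpha[:, None] + p @ v                 # rescale and extend the output
    return acc, l_i, m_ij                              # m_ij becomes the next m_i

# Two-tile check against a plain softmax (sm_scale = 0.5 chosen for illustration).
rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 16))
values = rng.standard_normal((16, 8))
qk_scale = 0.5 * 1.44269504
acc, l_i, m_i = np.zeros((4, 8)), np.zeros(4), np.full(4, -np.inf)
for t in range(2):  # two BLOCK_N = 8 tiles
    acc, l_i, m_i = online_softmax_tile_update(acc, l_i, m_i,
                                               scores[:, 8 * t:8 * t + 8],
                                               values[8 * t:8 * t + 8], qk_scale)
p_ref = np.exp(0.5 * scores - (0.5 * scores).max(axis=1, keepdims=True))
ref = (p_ref / p_ref.sum(axis=1, keepdims=True)) @ values
assert np.allclose(acc / l_i[:, None], ref)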
@@ -75,25 +79,28 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
                                BLOCK_N: tl.constexpr,  #
                                STAGE: tl.constexpr  #
                                ):  # pylint: disable=unused-argument
-
+    dtype = tl.float16
+    tl.static_assert(BLOCK_N <= BLOCK_DMODEL)
     start_m = tl.program_id(2)
     off_z = tl.program_id(0)
     off_h = tl.program_id(1)
-    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
+    offset_y = off_z * (N_CTX * H) + off_h * N_CTX
     if N_CTX <= 512:
         start_m = tl.program_id(0)
         off_z = tl.program_id(2)
-        qvk_offset = off_z.to(tl.int64) * stride_qh
+        offset_y = off_z * N_CTX
+
+    y_dim = Z * H * N_CTX
+    desc_q = tl.make_tensor_descriptor(Q, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
+                                       block_shape=[BLOCK_M, BLOCK_DMODEL])
+    desc_v = tl.make_tensor_descriptor(V, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
+                                       block_shape=[BLOCK_N, BLOCK_DMODEL])
+    desc_k = tl.make_tensor_descriptor(K, shape=[BLOCK_DMODEL, y_dim], strides=[1, BLOCK_DMODEL],
+                                       block_shape=[BLOCK_DMODEL, BLOCK_N])
+    desc_o = tl.make_tensor_descriptor(Out, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
+                                       block_shape=[BLOCK_M, BLOCK_DMODEL])
 
-    # tensor descriptors
-    Q_desc = tl.make_tensor_descriptor(base=Q + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk),
-                                       block_shape=(BLOCK_M, BLOCK_DMODEL))
-    V_desc = tl.make_tensor_descriptor(base=V + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn),
-                                       block_shape=(BLOCK_N, BLOCK_DMODEL))
-    K_desc = tl.make_tensor_descriptor(base=K + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn),
-                                       block_shape=(BLOCK_DMODEL, BLOCK_N))
-    O_desc = tl.make_tensor_descriptor(base=Out + qvk_offset, shape=(N_CTX, BLOCK_DMODEL),
-                                       strides=(stride_om, stride_on), block_shape=(BLOCK_M, BLOCK_DMODEL))
+    qo_offset_y = offset_y + start_m * BLOCK_M
     # initialize offsets
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
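For context on this replacement: the old descriptors were rebased per (batch, head) through qvk_offset, while the new ones describe each tensor once as a 2-D [y_dim, BLOCK_DMODEL] view of the whole allocation and let offset_y pick the (off_z, off_h) slice by row. A standalone NumPy check of that indexing, assuming (as strides=[BLOCK_DMODEL, 1] implies) tensors contiguous in (Z, H, N_CTX, BLOCK_DMODEL) layout:

import numpy as np

Z, H, N_CTX, D = 2, 3, 8, 4
q = np.arange(Z * H * N_CTX * D, dtype=np.float32).reshape(Z, H, N_CTX, D)
q2d = q.reshape(Z * H * N_CTX, D)               # the [y_dim, BLOCK_DMODEL] descriptor view
off_z, off_h = 1, 2
offset_y = off_z * (N_CTX * H) + off_h * N_CTX  # as computed in the kernel
assert np.array_equal(q2d[offset_y:offset_y + N_CTX], q[off_z, off_h])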
@@ -105,27 +112,29 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
     qk_scale = sm_scale
     qk_scale *= 1.44269504  # 1/log(2)
     # load q: it will stay in SRAM throughout
-    q = Q_desc.load([start_m * BLOCK_M, 0])
+    q = desc_q.load([qo_offset_y, 0])
     # stage 1: off-band
-    # For causal = True, STAGE = 3, the kernel gets 1 as its STAGE
-    # For causal = False, STAGE = 1, the kernel gets 3 as its STAGE
+    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
+    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
     if STAGE & 1:
-        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_desc, V_desc,  #
-                                        start_m, qk_scale,  #
+        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q,  #
+                                        desc_k, desc_v,  #
+                                        offset_y, dtype, start_m, qk_scale,  #
                                         BLOCK_M, BLOCK_DMODEL, BLOCK_N,  #
                                         4 - STAGE, offs_m, offs_n, N_CTX  #
                                         )
     # stage 2: on-band
     if STAGE & 2:
-        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_desc, V_desc,  #
-                                        start_m, qk_scale,  #
+        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q,  #
+                                        desc_k, desc_v,  #
+                                        offset_y, dtype, start_m, qk_scale,  #
                                         BLOCK_M, BLOCK_DMODEL, BLOCK_N,  #
                                         2, offs_m, offs_n, N_CTX  #
                                         )
     # epilogue
     m_i += tl.math.log2(l_i)
     acc = acc / l_i[:, None]
-    O_desc.store([start_m * BLOCK_M, 0], acc.to(Out.type.element_ty))
+    desc_o.store([qo_offset_y, 0], acc.to(Out.type.element_ty))
 
 
 def get_benchmark(
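A last note on desc_k: it is built over the same memory as K but with shape [BLOCK_DMODEL, y_dim] and strides [1, BLOCK_DMODEL], i.e. a transposed view, so the inner loop can load K^T tiles for tl.dot(q, k) without an explicit transpose. A small NumPy sketch of why swapping the strides yields the transpose, again assuming contiguous K:

import numpy as np

y_dim, D = 16, 8
k = np.arange(y_dim * D, dtype=np.float32).reshape(y_dim, D)  # contiguous [y_dim, D]
kT = np.lib.stride_tricks.as_strided(                         # shape/strides of desc_k
    k, shape=(D, y_dim), strides=(k.itemsize, D * k.itemsize))
assert np.array_equal(kT, k.T)                                # same bytes, free transpose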