from triton_kernels_benchmark import flash_attention_benchmark

+ # FIXME: Revert temporary source code modification done in last commit of PR #4399.
+

# pylint: disable=unused-argument
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
-                     K_desc, V_desc,  #
-                     start_m, qk_scale,  #
+                     desc_k, desc_v,  #
+                     offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr):
@@ -24,13 +26,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
    # causal = False
    else:
        lo, hi = 0, N_CTX
-     off_k = lo
-     off_v = lo
+     offsetk_y = offset_y + lo
+     offsetv_y = offset_y + lo
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
-         k = K_desc.load([0, off_k])
+         k = desc_k.load([0, offsetk_y])
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        if STAGE == 2:
@@ -43,18 +45,20 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
        qk = qk * qk_scale - m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
-         # -- update m_i and l_i
+         # -- compute correction factor
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
-         # update acc
-         v = V_desc.load([off_v, 0])
+         # prepare p and v for the dot
+         v = desc_v.load([offsetv_y, 0])
+         # note that this non transposed v for FP8 is only supported on Blackwell
        acc += tl.dot(p.to(tl.float16), v)
        # update m_i and l_i
+         # place this at the end of the loop to reduce register pressure
        m_i = m_ij
-         off_v += BLOCK_N
-         off_k += BLOCK_N
+         offsetk_y += BLOCK_N
+         offsetv_y += BLOCK_N
    return acc, l_i, m_i

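Note on the renamed comment above: "compute correction factor" refers to the standard online-softmax rescaling. Whenever the running row maximum m_i rises to m_ij, the partial sum l_i and the accumulator acc are scaled by alpha = exp2(m_i - m_ij) so that earlier blocks stay consistent with the new maximum. Below is a minimal NumPy sketch of that invariant, illustrative only and not part of this change; shapes, block size, and the use of exp2 for both the streaming and reference paths are arbitrary choices:

import numpy as np

# Illustrative only: streaming softmax with the exp2-based correction factor,
# checked against a single-pass reference (one query row, 64 keys, blocks of 16).
rng = np.random.default_rng(0)
scores = rng.standard_normal(64).astype(np.float32)
v = rng.standard_normal((64, 16)).astype(np.float32)

m_i = np.float32(-np.inf)          # running maximum, as in the kernel
l_i = np.float32(0.0)              # running sum of exp2(scores - m_i)
acc = np.zeros(16, dtype=np.float32)

for start in range(0, 64, 16):
    blk = scores[start:start + 16]
    m_ij = max(m_i, blk.max())     # new running maximum
    p = np.exp2(blk - m_ij)
    alpha = np.exp2(m_i - m_ij)    # correction factor for previously accumulated blocks
    l_i = l_i * alpha + p.sum()
    acc = acc * alpha + p @ v[start:start + 16]
    m_i = m_ij

p_ref = np.exp2(scores - scores.max())
assert np.allclose(acc / l_i, (p_ref @ v) / p_ref.sum(), atol=1e-5)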
@@ -75,25 +79,28 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
                               BLOCK_N: tl.constexpr,  #
                               STAGE: tl.constexpr  #
                               ):  # pylint: disable=unused-argument
-
+     dtype = tl.float16
+     tl.static_assert(BLOCK_N <= BLOCK_DMODEL)
    start_m = tl.program_id(2)
    off_z = tl.program_id(0)
    off_h = tl.program_id(1)
-     qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
+     offset_y = off_z * (N_CTX * H) + off_h * N_CTX
    if N_CTX <= 512:
        start_m = tl.program_id(0)
        off_z = tl.program_id(2)
-         qvk_offset = off_z.to(tl.int64) * stride_qh
+         offset_y = off_z * N_CTX
+
+     y_dim = Z * H * N_CTX
+     desc_q = tl.make_tensor_descriptor(Q, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
+                                        block_shape=[BLOCK_M, BLOCK_DMODEL])
+     desc_v = tl.make_tensor_descriptor(V, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
+                                        block_shape=[BLOCK_N, BLOCK_DMODEL])
+     desc_k = tl.make_tensor_descriptor(K, shape=[BLOCK_DMODEL, y_dim], strides=[1, BLOCK_DMODEL],
+                                        block_shape=[BLOCK_DMODEL, BLOCK_N])
+     desc_o = tl.make_tensor_descriptor(Out, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
+                                        block_shape=[BLOCK_M, BLOCK_DMODEL])

-     # tensor descriptors
-     Q_desc = tl.make_tensor_descriptor(base=Q + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk),
-                                        block_shape=(BLOCK_M, BLOCK_DMODEL))
-     V_desc = tl.make_tensor_descriptor(base=V + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn),
-                                        block_shape=(BLOCK_N, BLOCK_DMODEL))
-     K_desc = tl.make_tensor_descriptor(base=K + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn),
-                                        block_shape=(BLOCK_DMODEL, BLOCK_N))
-     O_desc = tl.make_tensor_descriptor(base=Out + qvk_offset, shape=(N_CTX, BLOCK_DMODEL),
-                                        strides=(stride_om, stride_on), block_shape=(BLOCK_M, BLOCK_DMODEL))
+     qo_offset_y = offset_y + start_m * BLOCK_M
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
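A quick sanity check of the new flat addressing scheme: the descriptors view Q/V/Out as 2D tensors of shape [Z * H * N_CTX, BLOCK_DMODEL] (and K transposed), with offset_y as the first row belonging to the current (off_z, off_h) pair. The sketch below is illustrative only; it assumes a contiguous (Z, H, N_CTX, D) layout, which is what strides=[BLOCK_DMODEL, 1] implies, and confirms that the row range read through desc_q matches plain 4D indexing:

import torch

# Illustrative only: the flattened row offsets used with the tensor descriptors
# address the same data as direct (batch, head, seq) indexing.
Z, H, N_CTX, D = 2, 3, 8, 4
q = torch.arange(Z * H * N_CTX * D, dtype=torch.float32).reshape(Z, H, N_CTX, D)
q2d = q.reshape(Z * H * N_CTX, D)                  # the [y_dim, BLOCK_DMODEL] view

off_z, off_h, start_m, BLOCK_M = 1, 2, 1, 4        # arbitrary program ids / block size
offset_y = off_z * (N_CTX * H) + off_h * N_CTX     # as computed in the kernel
qo_offset_y = offset_y + start_m * BLOCK_M         # first query row of this program

rows_flat = q2d[qo_offset_y:qo_offset_y + BLOCK_M]                    # rows desc_q.load([qo_offset_y, 0]) covers
rows_4d = q[off_z, off_h, start_m * BLOCK_M:(start_m + 1) * BLOCK_M]  # same rows via 4D indexing
assert torch.equal(rows_flat, rows_4d)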
@@ -105,27 +112,29 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    # load q: it will stay in SRAM throughout
-     q = Q_desc.load([start_m * BLOCK_M, 0])
+     q = desc_q.load([qo_offset_y, 0])
    # stage 1: off-band
-     # For causal = True, STAGE = 3, the kernel gets 1 as its STAGE
-     # For causal = False, STAGE = 1, the kernel gets 3 as its STAGE
+     # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
+     # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
-         acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_desc, V_desc,  #
-                                         start_m, qk_scale,  #
+         acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q,  #
+                                         desc_k, desc_v,  #
+                                         offset_y, dtype, start_m, qk_scale,  #
                                        BLOCK_M, BLOCK_DMODEL, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, N_CTX  #
                                        )
    # stage 2: on-band
    if STAGE & 2:
-         acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_desc, V_desc,  #
-                                         start_m, qk_scale,  #
+         acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q,  #
+                                         desc_k, desc_v,  #
+                                         offset_y, dtype, start_m, qk_scale,  #
                                        BLOCK_M, BLOCK_DMODEL, BLOCK_N,  #
                                        2, offs_m, offs_n, N_CTX  #
                                        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
-     O_desc.store([start_m * BLOCK_M, 0], acc.to(Out.type.element_ty))
+     desc_o.store([qo_offset_y, 0], acc.to(Out.type.element_ty))


def get_benchmark(
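A closing note, not part of the diff: kernels that build descriptors with tl.make_tensor_descriptor may need the host to register a scratch allocator before launch on backends that materialize descriptors in global memory (recent Triton releases expose this as triton.set_allocator). Whether the XPU backend requires it for this benchmark is an assumption, as is the device string below; the sketch only shows the shape of such a setup:

import torch
import triton

# Assumed host-side setup: provide scratch memory for on-device tensor
# descriptors. Signature follows recent Triton releases; "xpu" is assumed.
def alloc_fn(size: int, alignment: int, stream):
    return torch.empty(size, dtype=torch.int8, device="xpu")

triton.set_allocator(alloc_fn)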