@@ -94,49 +94,72 @@ def forward_kernel(
     EVEN_N: tl.constexpr,
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
+    QUERY_HEAD_GROUPS: tl.constexpr,
     NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
     off_b = off_hb // nheads
+
     off_h = off_hb % nheads
 
+    offs_qh = off_h * QUERY_HEAD_GROUPS + tl.arange(0, QUERY_HEAD_GROUPS)
+
     offs_m = start_m * BLOCK + tl.arange(0, BLOCK)
     offs_n = start_m * BLOCK + tl.arange(0, BLOCK)
     offs_d = tl.arange(0, BLOCK_HEADDIM)
 
     q_ptrs = (
-        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
+        Q +
+        off_b * stride_qb +
+        offs_qh[:, None, None] * stride_qh +
+        offs_m[None, :, None] * stride_qm +
+        offs_d[None, None, :]
     )
+
     k_ptrs = (
-        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
+        K +
+        off_b * stride_kb +
+        off_h * stride_kh +
+        offs_n[:, None] * stride_kn +
+        offs_d[None, :]
     )
+
     v_ptrs = (
-        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
+        V +
+        off_b * stride_vb +
+        off_h * stride_vh +
+        offs_n[:, None] * stride_vn +
+        offs_d[None, :]
     )
 
     # maximum
 
-    m_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
+    m_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
 
     # lse
 
-    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
+    offs_lse_qh = tl.arange(0, QUERY_HEAD_GROUPS)
+
+    lse_ptrs = (
+        Lse +
+        (off_hb + offs_lse_qh[:, None]) * seqlen_q_rounded +
+        offs_m[None, :]
+    )
 
-    lse_i = tl.zeros([BLOCK], dtype = tl.float32) - float("inf")
+    lse_i = tl.zeros([BLOCK * QUERY_HEAD_GROUPS], dtype = tl.float32) - float("inf")
 
     # output
 
-    offs_d = tl.arange(0, BLOCK_HEADDIM)
-
     out_ptrs = (
-        Out
-        + off_b * stride_ob
-        + off_h * stride_oh
-        + (offs_m[:, None] * stride_om + offs_d[None, :])
+        Out +
+        off_b * stride_ob +
+        offs_qh[:, None, None] * stride_oh +
+        offs_m[None, :, None] * stride_om +
+        offs_d[None, None, :]
     )
 
-    acc_o = tl.zeros([BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
+    acc_o = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM], dtype = tl.float32)
 
     # load queries, keys, values
 
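# Illustrative sketch, not part of the diff: what the 3D q_ptrs / out_ptrs grids above
# address. One program handles one kv head (off_h) and gathers the QUERY_HEAD_GROUPS
# query heads that share it. Assumes a contiguous q of shape (batch, q_heads, seqlen, dim);
# gather_query_group and its arguments are hypothetical names for this sketch only.
import torch

def gather_query_group(q, off_b, off_h, G, start_m, BLOCK):
    offs_qh = off_h * G + torch.arange(G)           # the G query heads sharing kv head off_h
    offs_m = start_m * BLOCK + torch.arange(BLOCK)  # query rows of this block
    # equivalent of the kernel's strided pointer grid: shape (G, BLOCK, dim),
    # which the kernel later flattens to (G * BLOCK, dim) before tl.dot
    return q[off_b][offs_qh][:, offs_m]

q = torch.randn(2, 8, 64, 16)                                              # q_heads = 8
blk = gather_query_group(q, off_b = 0, off_h = 1, G = 4, start_m = 0, BLOCK = 32)
assert blk.shape == (4, 32, 16)                                            # kv head 1 serves query heads 4..7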
@@ -153,6 +176,8 @@ def forward_kernel(
                 q_ptrs, mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other = 0.0
             )
 
+    q = q.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM])
+
     if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             k = tl.load(k_ptrs)
@@ -172,14 +197,18 @@ def forward_kernel(
                 other = 0.0,
             )
 
-    qk = tl.zeros([BLOCK, BLOCK], dtype = tl.float32)
+    qk = tl.zeros([QUERY_HEAD_GROUPS * BLOCK, BLOCK], dtype = tl.float32)
     qk += tl.dot(q, tl.trans(k))
 
     if not EVEN_N:
         qk += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf"))
 
+    qk = qk.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK])
+
     qk += tl.where(offs_m[:, None] >= offs_n[None, :], 0, float("-inf"))
 
+    qk = qk.reshape([QUERY_HEAD_GROUPS * BLOCK, BLOCK])
+
     m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
     p = tl.exp(qk * softmax_scale - m_ij[:, None])
 
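# Illustrative sketch, not part of the diff: the reshape dance around the causal mask.
# qk stays flat as (G * BLOCK, BLOCK) for tl.dot and the row-wise softmax statistics,
# but the causal comparison offs_m[:, None] >= offs_n[None, :] is only (BLOCK, BLOCK),
# so it is broadcast over a (G, BLOCK, BLOCK) view and the result flattened again.
# G and BLOCK below are stand-ins for QUERY_HEAD_GROUPS and the block size.
import torch

G, BLOCK = 4, 8
qk = torch.randn(G * BLOCK, BLOCK)

offs = torch.arange(BLOCK)
causal = torch.where(offs[:, None] >= offs[None, :], 0.0, float("-inf"))  # (BLOCK, BLOCK)

qk = qk.reshape(G, BLOCK, BLOCK) + causal   # mask broadcast across the group dimension
qk = qk.reshape(G * BLOCK, BLOCK)           # back to flat rows for max / exp / sum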
@@ -303,10 +332,13 @@ def forward_kernel(
 
     # write back lse
 
+    lse_i = lse_i.reshape([QUERY_HEAD_GROUPS, BLOCK])
     tl.store(lse_ptrs, lse_i)
 
     # write to output
 
+    acc_o = acc_o.reshape([QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM])
+
     if EVEN_M:
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o)
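# Illustrative sketch, not part of the diff: tl.store writes element-wise, so the value
# must match the pointer grid's shape. lse_ptrs was built above as a (QUERY_HEAD_GROUPS,
# BLOCK) grid and out_ptrs as (QUERY_HEAD_GROUPS, BLOCK, BLOCK_HEADDIM), so the
# accumulators, kept flat as (G * BLOCK, ...), are viewed back into grouped form before
# the stores. G, BLOCK and D are stand-ins.
import torch

G, BLOCK, D = 4, 8, 16
lse_i = torch.randn(G * BLOCK)
acc_o = torch.randn(G * BLOCK, D)

lse_rows = lse_i.reshape(G, BLOCK)      # matches the (G, BLOCK) lse_ptrs grid
out_rows = acc_o.reshape(G, BLOCK, D)   # matches the (G, BLOCK, D) out_ptrs grid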
@@ -331,13 +363,15 @@ def flash_attn_forward(
     q, k, v, kv_block_indices = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v, kv_block_indices)]
 
     batch, nheads, seqlen_q, dim, device = *q.shape, q.device
-    _, _, seqlen_k, _ = k.shape
+    _, kv_heads, seqlen_k, _ = k.shape
+    assert divisible_by(nheads, kv_heads)
+    head_groups = nheads // kv_heads
 
     num_selected_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
 
-    assert k.shape == (batch, nheads, seqlen_k, dim)
-    assert v.shape == (batch, nheads, seqlen_k, dim)
+    assert k.shape == (batch, kv_heads, seqlen_k, dim)
+    assert v.shape == (batch, kv_heads, seqlen_k, dim)
     assert dim <= 128, "only support head dimensions up to 128"
     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
     assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
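# Illustrative sketch, not part of the diff: the grouped-query shape bookkeeping the
# assertions above encode. All sizes below are made up.
import torch

batch, q_heads, kv_heads, seqlen, dim = 2, 8, 2, 128, 64

q = torch.randn(batch, q_heads, seqlen, dim, dtype = torch.float16)
k = torch.randn(batch, kv_heads, seqlen, dim, dtype = torch.float16)
v = torch.randn(batch, kv_heads, seqlen, dim, dtype = torch.float16)

assert q_heads % kv_heads == 0       # divisible_by(nheads, kv_heads) in the diff
head_groups = q_heads // kv_heads    # here 4 query heads share each kv head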
@@ -353,7 +387,8 @@ def flash_attn_forward(
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(dim), 16)
     num_warps = 4 if dim <= 64 else 8
-    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)
+
+    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * kv_heads) # kv heads here, as the grouped query heads are all loaded by one program, following the paper
 
     forward_kernel[grid](
         q,
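# Illustrative sketch, not part of the diff: what the grid change amounts to. A program
# is launched per (query block, batch * kv head) instead of per (query block, batch *
# query head); each program now covers head_groups query heads, so the same rows are
# processed by fewer, larger programs. cdiv mirrors triton.cdiv; sizes are made up.
cdiv = lambda n, d: (n + d - 1) // d

seqlen_q, block_size = 1024, 64
batch, q_heads, kv_heads = 2, 8, 2
head_groups = q_heads // kv_heads

old_grid = (cdiv(seqlen_q, block_size), batch * q_heads)   # one query head per program
new_grid = (cdiv(seqlen_q, block_size), batch * kv_heads)  # head_groups query heads per program

assert old_grid[1] == new_grid[1] * head_groups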
@@ -388,6 +423,7 @@ def flash_attn_forward(
         seqlen_k // 32,
         BLOCK_HEADDIM,
         BLOCK = block_size,
+        QUERY_HEAD_GROUPS = head_groups,
         NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
         num_warps = num_warps,
         num_stages = 1,
@@ -1090,8 +1126,6 @@ def forward(
         assert divisible_by(q_heads, kv_heads)
         head_groups = q_heads // kv_heads
 
-        fk, fv, selected_block_indices, fmask = tuple(repeat(t, 'b h ... -> b (h g) ...', g = head_groups) for t in (fk, fv, selected_block_indices, fmask))
-
         fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))
 
         out, lse = flash_attn_forward(
@@ -1101,6 +1135,8 @@ def forward(
             block_size = block_size
         )
 
+        fk, fv, selected_block_indices, fmask = tuple(repeat(t, 'b h ... -> b (h g) ...', g = head_groups) for t in (fk, fv, selected_block_indices, fmask))
+
         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
 
         ctx._saved_variables = (
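# Illustrative sketch, not part of the diff: the einops.repeat that moved below the
# forward call. The Triton forward now consumes fk / fv with kv_heads directly; the
# per-query-head expansion is only applied to the tensors saved for backward. Shapes
# are made up.
import torch
from einops import repeat

kv_heads, head_groups = 2, 4
fk = torch.randn(1, kv_heads, 128, 64)

fk_expanded = repeat(fk, 'b h ... -> b (h g) ...', g = head_groups)
assert fk_expanded.shape == (1, kv_heads * head_groups, 128, 64)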