 import triton
 import triton.language as tl
 
+from vllm.utils import is_navi
+
 torch_dtype: tl.constexpr = torch.float16
 
 
@@ -217,88 +219,80 @@ def _attn_fwd_inner(
     return acc, l_i, m_i
 
 
-@triton.autotune(
-    configs=[
+def get_cdna_autotune_configs():
+    return [
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 64,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 64,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 1,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 1,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": True,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': True
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 64,
-                "BLOCK_N": 64,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 64,
+                'BLOCK_N': 64,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 32,
-                "BLOCK_N": 32,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         # TODO: This config fails with head_size not pow2 with data mismatches.
         # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
         #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
@@ -314,8 +308,93 @@ def _attn_fwd_inner(
         #     num_stages=1,
         #     num_warps=4,
         # ),
-    ],
-    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'],
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
+
+
+def get_rdna_autotune_configs():
+    return [
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        # Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 4,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 2,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+        # # Fall-back config.
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 1,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
+
+
+def get_autotune_configs():
+    if is_navi():
+        return get_rdna_autotune_configs()
+    else:
+        return get_cdna_autotune_configs()
+
+
+autotune_configs, autotune_keys = get_autotune_configs()
+
+
+@triton.autotune(
+    configs=autotune_configs,
+    key=autotune_keys,
+    use_cuda_graph=True,
 )
 @triton.jit
 def attn_fwd(
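
Note: the autotune config list is now chosen at import time via `is_navi()`, imported from `vllm.utils` above; its real implementation lives in that module. As a rough orientation for readers, below is a minimal sketch of the kind of check such a helper typically performs. This is an illustrative approximation, not the vllm.utils source: it assumes a ROCm build of PyTorch, where device properties expose a `gcnArchName` string and RDNA ("Navi") GPUs report an architecture name starting with "gfx1".

```python
# Illustrative sketch only -- approximates an is_navi()-style check on ROCm;
# the PR itself uses the real helper, vllm.utils.is_navi.
import torch


def is_navi_sketch() -> bool:
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    # On ROCm builds, gcnArchName is e.g. "gfx1100" (RDNA3) or "gfx90a" (CDNA2);
    # the attribute does not exist on CUDA builds, hence the getattr fallback.
    arch = getattr(props, "gcnArchName", "")
    return arch.startswith("gfx1")
```

The split matters because RDNA and CDNA favor different launch shapes, which is presumably why the RDNA list above sticks to 32/16 tile sizes with num_warps=2 while the CDNA list keeps the larger BLOCK_M/BLOCK_N configurations.
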
@@ -833,6 +912,10 @@ def check_and_convert(t, scale):
     p_descale = 1.0 / p_scale
     o_descale = 1.0 / o_scale
 
+    if is_navi():
+        max_seqlens_q = 0
+        max_seqlens_k = 0
+
     attn_fwd[grid](
         q,
         k,