@@ -440,49 +440,25 @@ def forward(ctx, q, k, v, causal, sm_scale):
440440 assert Lq == Lk and Lk == Lv
441441 assert Lk in {16 , 32 , 64 , 128 }
442442 o = torch .empty_like (q )
443- BLOCK_M = 128
444- BLOCK_N = 64
445- num_stages = 3
446- num_warps = 8 if Lq == 64 else 16
447443 stage = 3 if causal else 1
448444 grid = lambda args : (q .shape [0 ], q .shape [1 ], triton .cdiv (q .shape [2 ], args ['BLOCK_M' ]))
449445 n_ctx = q .shape [2 ]
450446 if n_ctx <= 512 :
451447 grid = lambda args : (triton .cdiv (q .shape [2 ], args ['BLOCK_M' ]), 1 , q .shape [0 ] * q .shape [1 ])
452448 M = torch .empty ((q .shape [0 ], q .shape [1 ], q .shape [2 ]), device = q .device , dtype = torch .float32 )
453449
454- if os .getenv ('TRITON_INTEL_ADVANCED_PATH' , '0' ) == '0' :
455- # default pipeline
456- _attention .tune_attn_fwd [grid ]( # pylint: disable=unsubscriptable-object
457- q , k , v , sm_scale , M , o , #
458- q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
459- k .stride (0 ), k .stride (1 ), k .stride (2 ), k .stride (3 ), #
460- v .stride (0 ), v .stride (1 ), v .stride (2 ), v .stride (3 ), #
461- o .stride (0 ), o .stride (1 ), o .stride (2 ), o .stride (3 ), #
462- q .shape [0 ], q .shape [1 ], #
463- N_CTX = q .shape [2 ], #
464- BLOCK_DMODEL = Lk , #
465- STAGE = stage , #
466- split_barriers_scope = 'None' , # possible scope value: 'Subgroup','Workgroup'
467- )
468- else :
469- _attention .attn_fwd [grid ]( # pylint: disable=unsubscriptable-object
470- q , k , v , sm_scale , M , o , #
471- q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
472- k .stride (0 ), k .stride (1 ), k .stride (2 ), k .stride (3 ), #
473- v .stride (0 ), v .stride (1 ), v .stride (2 ), v .stride (3 ), #
474- o .stride (0 ), o .stride (1 ), o .stride (2 ), o .stride (3 ), #
475- q .shape [0 ], q .shape [1 ], #
476- N_CTX = q .shape [2 ], #
477- BLOCK_M = BLOCK_M , #
478- BLOCK_N = BLOCK_N , #
479- BLOCK_DMODEL = Lk , #
480- STAGE = stage , #
481- num_warps = num_warps , #
482- num_stages = num_stages , #
483- grf_mode = 'large' , #
484- advanced_path = True , #
485- )
450+ _attention .tune_attn_fwd [grid ]( # pylint: disable=unsubscriptable-object
451+ q , k , v , sm_scale , M , o , #
452+ q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
453+ k .stride (0 ), k .stride (1 ), k .stride (2 ), k .stride (3 ), #
454+ v .stride (0 ), v .stride (1 ), v .stride (2 ), v .stride (3 ), #
455+ o .stride (0 ), o .stride (1 ), o .stride (2 ), o .stride (3 ), #
456+ q .shape [0 ], q .shape [1 ], #
457+ N_CTX = q .shape [2 ], #
458+ BLOCK_DMODEL = Lk , #
459+ STAGE = stage , #
460+ split_barriers_scope = 'None' , # possible scope value: 'Subgroup','Workgroup'
461+ )
486462
487463 ctx .save_for_backward (q , k , v , o , M )
488464 ctx .sm_scale = sm_scale
0 commit comments