 
 from triton_kernels_benchmark import flash_attention_benchmark
 
-# FIXME: Revert temporary source code modification done in last commit of PR #4399.
-
 
 # pylint: disable=unused-argument
 @triton.jit
 def _attn_fwd_inner(acc, l_i, m_i, q,  #
                     desc_k, desc_v,  #
                     offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
-                    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,  #
+                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                     STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                     N_CTX: tl.constexpr):
     # range of values handled by this stage
@@ -32,7 +30,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
     for start_n in tl.range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = desc_k.load([0, offsetk_y])
+        k = desc_k.load([offsetk_y, 0]).T
         qk = tl.dot(q, k)
         if STAGE == 2:
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -95,8 +93,8 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out,  #
                                        block_shape=[BLOCK_M, BLOCK_DMODEL])
     desc_v = tl.make_tensor_descriptor(V, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
                                        block_shape=[BLOCK_N, BLOCK_DMODEL])
-    desc_k = tl.make_tensor_descriptor(K, shape=[BLOCK_DMODEL, y_dim], strides=[1, BLOCK_DMODEL],
-                                       block_shape=[BLOCK_DMODEL, BLOCK_N])
+    desc_k = tl.make_tensor_descriptor(K, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
+                                       block_shape=[BLOCK_N, BLOCK_DMODEL])
     desc_o = tl.make_tensor_descriptor(Out, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
                                        block_shape=[BLOCK_M, BLOCK_DMODEL])
 
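For reference, a minimal NumPy sketch (plain Python, not part of the diff, with hypothetical sizes) of why the old and new K-descriptor layouts yield the same tile: the old descriptor views K as a transposed [HEAD_DIM, y_dim] buffer and loads a [HEAD_DIM, BLOCK_N] block at [0, offsetk_y], while the new one keeps the natural row-major [y_dim, HEAD_DIM] view, loads a [BLOCK_N, HEAD_DIM] block at [offsetk_y, 0], and transposes it in-kernel.

import numpy as np

# Hypothetical sizes; y_dim stands in for K's flattened leading dimension.
y_dim, HEAD_DIM, BLOCK_N = 16, 8, 4
K = np.arange(y_dim * HEAD_DIM, dtype=np.float32).reshape(y_dim, HEAD_DIM)
offsetk_y = 4  # row offset of the K block for one loop iteration

# Old layout: K viewed as [HEAD_DIM, y_dim] with strides [1, HEAD_DIM]
# (i.e. K transposed); a [HEAD_DIM, BLOCK_N] block is read at [0, offsetk_y].
k_old = K.T[:, offsetk_y:offsetk_y + BLOCK_N]

# New layout: the row-major [y_dim, HEAD_DIM] view with strides [HEAD_DIM, 1];
# a [BLOCK_N, HEAD_DIM] block is read at [offsetk_y, 0] and then transposed.
k_new = K[offsetk_y:offsetk_y + BLOCK_N, :].T

assert np.array_equal(k_old, k_new)  # same [HEAD_DIM, BLOCK_N] operand for tl.dot(q, k)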