1010# pylint: disable=unused-argument
1111@triton .jit
1212def _attn_fwd_inner (acc , l_i , m_i , q , #
13- K_block_ptr , V_desc , #
13+ K_desc , V_desc , #
1414 start_m , qk_scale , #
1515 BLOCK_M : tl .constexpr , BLOCK_DMODEL : tl .constexpr , BLOCK_N : tl .constexpr , #
1616 STAGE : tl .constexpr , offs_m : tl .constexpr , offs_n : tl .constexpr , #
@@ -24,13 +24,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
2424 # causal = False
2525 else :
2626 lo , hi = 0 , N_CTX
27- K_block_ptr = tl . advance ( K_block_ptr , ( 0 , lo ))
27+ off_k = lo
2828 off_v = lo
2929 # loop over k, v and update accumulator
3030 for start_n in range (lo , hi , BLOCK_N ):
3131 start_n = tl .multiple_of (start_n , BLOCK_N )
3232 # -- compute qk ----
33- k = tl .load (K_block_ptr )
33+ k = K_desc .load ([ 0 , off_k ] )
3434 qk = tl .zeros ([BLOCK_M , BLOCK_N ], dtype = tl .float32 )
3535 qk += tl .dot (q , k )
3636 if STAGE == 2 :
@@ -54,7 +54,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
5454 # update m_i and l_i
5555 m_i = m_ij
5656 off_v += BLOCK_N
57- K_block_ptr = tl . advance ( K_block_ptr , ( 0 , BLOCK_N ))
57+ off_k += BLOCK_N
5858 return acc , l_i , m_i
5959
6060
@@ -90,9 +90,8 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
9090 block_shape = (BLOCK_M , BLOCK_DMODEL ))
9191 V_desc = tl .make_tensor_descriptor (base = V + qvk_offset , shape = (N_CTX , BLOCK_DMODEL ), strides = (stride_vk , stride_vn ),
9292 block_shape = (BLOCK_N , BLOCK_DMODEL ))
93- #FIXME: change to a tensor descriptor.
94- K_block_ptr = tl .make_block_ptr (base = K + qvk_offset , shape = (BLOCK_DMODEL , N_CTX ), strides = (stride_kk , stride_kn ),
95- offsets = (0 , 0 ), block_shape = (BLOCK_DMODEL , BLOCK_N ), order = (0 , 1 ))
93+ K_desc = tl .make_tensor_descriptor (base = K + qvk_offset , shape = (BLOCK_DMODEL , N_CTX ), strides = (stride_kk , stride_kn ),
94+ block_shape = (BLOCK_DMODEL , BLOCK_N ))
9695 O_desc = tl .make_tensor_descriptor (base = Out + qvk_offset , shape = (N_CTX , BLOCK_DMODEL ),
9796 strides = (stride_om , stride_on ), block_shape = (BLOCK_M , BLOCK_DMODEL ))
9897 # initialize offsets
@@ -111,14 +110,14 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
111110 # For causal = True, STAGE = 3 and _attn_fwd_inner gets 4 - STAGE = 1 as its STAGE
112111 # For causal = False, STAGE = 1 and _attn_fwd_inner gets 4 - STAGE = 3 as its STAGE
113112 if STAGE & 1 :
114- acc , l_i , m_i = _attn_fwd_inner (acc , l_i , m_i , q , K_block_ptr , V_desc , #
113+ acc , l_i , m_i = _attn_fwd_inner (acc , l_i , m_i , q , K_desc , V_desc , #
115114 start_m , qk_scale , #
116115 BLOCK_M , BLOCK_DMODEL , BLOCK_N , #
117116 4 - STAGE , offs_m , offs_n , N_CTX #
118117 )
119118 # stage 2: on-band
120119 if STAGE & 2 :
121- acc , l_i , m_i = _attn_fwd_inner (acc , l_i , m_i , q , K_block_ptr , V_desc , #
120+ acc , l_i , m_i = _attn_fwd_inner (acc , l_i , m_i , q , K_desc , V_desc , #
122121 start_m , qk_scale , #
123122 BLOCK_M , BLOCK_DMODEL , BLOCK_N , #
124123 2 , offs_m , offs_n , N_CTX #
0 commit comments