@@ -243,9 +243,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
                     BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr,
                     PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr,
                     RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
-                    QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr):
+                    QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr,
+                    ENABLE_PIPELINING: tl.constexpr):
     # loop over k, v, and update accumulator
-    for start_n in range(block_min, block_max, BLOCK_N):
+    num_stages: tl.constexpr = None if ENABLE_PIPELINING else 1  # Set num_stages == 1 if we want to disable pipelining
+    for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages):
         # For padded blocks, we will overrun the tensor size if
         # we load all BLOCK_N. For others, the blocks are all within range.
         if MASK_STEPS:
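For context: `tl.range` is Triton's loop construct whose `num_stages` argument controls software pipelining of the loop body. `num_stages=None` leaves the pipelining depth to the compiler, while `num_stages=1` forces a single stage, i.e. no pipelining. Below is a minimal, self-contained sketch of the same pattern, assuming a Triton version where `tl.range` accepts `num_stages` (3.0+); the kernel name, its arguments, and the launch are illustrative and not part of this change:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def row_sum_kernel(x_ptr, out_ptr, n_cols, BLOCK_N: tl.constexpr,
                   ENABLE_PIPELINING: tl.constexpr):
    # Each program instance reduces one row of a (rows, n_cols) matrix.
    row = tl.program_id(0)
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    # Same pattern as the diff above: None lets the compiler pick the
    # pipelining depth, 1 disables software pipelining for this loop.
    num_stages: tl.constexpr = None if ENABLE_PIPELINING else 1
    for start in tl.range(0, n_cols, BLOCK_N, num_stages=num_stages):
        offs = start + tl.arange(0, BLOCK_N)
        mask = offs < n_cols  # guard the ragged tail of the row
        acc += tl.load(x_ptr + row * n_cols + offs, mask=mask, other=0.0)
    tl.store(out_ptr + row, tl.sum(acc, axis=0))


# Usage sketch (requires a GPU):
x = torch.randn(4, 1000, device="cuda")
out = torch.empty(4, device="cuda")
row_sum_kernel[(4,)](x, out, x.shape[1], BLOCK_N=128, ENABLE_PIPELINING=True)
```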
@@ -674,7 +676,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
                                         # _, MASK_STEPS, ...
                                         PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX,
                                         PADDED_HEAD, ACTUAL_BLOCK_DMODEL, QK_SCALE, INT8_GEMM, USE_P_SCALE,
-                                        INT8_KV)
+                                        INT8_KV, True)
         block_min = block_max
         block_max = n_blocks * BLOCK_N
 
@@ -698,7 +700,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
                                         p_scale, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
                                         # _, MASK_STEPS, ...
                                         PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL,
-                                        QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV)
+                                        QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV, False)
 
     if INT8 and not INT8_KV:
         if USE_P_SCALE:
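Note how the two call sites split the work: the first `_attn_fwd_inner` call covers the fully in-range K/V blocks (`MASK_STEPS=False`) with pipelining enabled (`True`), while the second covers the remaining boundary blocks (`MASK_STEPS=True`) with pipelining disabled (`False`), presumably because the masked tail iterations benefit less from, or interact poorly with, software pipelining.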