
Commit 5c4b1fb

[FA] Disable pipelining for causal loop
1 parent: 1082cd2

File tree

2 files changed: +6 -10 lines

fa/flash-attention.py

Lines changed: 6 additions & 4 deletions
@@ -243,9 +243,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
                     BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr,
                     PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr,
                     RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
-                    QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr):
+                    QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr,
+                    ENABLE_PIPELINING: tl.constexpr):
     # loop over k, v, and update accumulator
-    for start_n in range(block_min, block_max, BLOCK_N):
+    num_stages: tl.constexpr = None if ENABLE_PIPELINING else 1  # Set num_stages==1 if we want to disable pipelining
+    for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages):
         # For padded blocks, we will overrun the tensor size if
         # we load all BLOCK_N. For others, the blocks are all within range.
         if MASK_STEPS:
@@ -674,7 +676,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
             # _, MASK_STEPS, ...
             PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX,
             PADDED_HEAD, ACTUAL_BLOCK_DMODEL, QK_SCALE, INT8_GEMM, USE_P_SCALE,
-            INT8_KV)
+            INT8_KV, True)
         block_min = block_max
         block_max = n_blocks * BLOCK_N

@@ -698,7 +700,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
             p_scale, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
             # _, MASK_STEPS, ...
             PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL,
-            QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV)
+            QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV, False)

     if INT8 and not INT8_KV:
         if USE_P_SCALE:
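
For context on the mechanism: Triton's tl.range accepts a num_stages argument that overrides the compiler's software-pipelining depth for that one loop, where num_stages=1 means no pipelining and None defers to the compiler default. Below is a minimal standalone sketch of the same trick; the kernel and its names (row_sum_kernel, BLOCK, etc.) are hypothetical and not part of this commit:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def row_sum_kernel(x_ptr, out_ptr, n_cols, BLOCK: tl.constexpr,
                       ENABLE_PIPELINING: tl.constexpr):
        row = tl.program_id(0)
        acc = tl.zeros([BLOCK], dtype=tl.float32)
        # Same pattern as the commit: num_stages == 1 disables software
        # pipelining for this loop only; None leaves the depth to the compiler.
        num_stages: tl.constexpr = None if ENABLE_PIPELINING else 1
        for off in tl.range(0, n_cols, BLOCK, num_stages=num_stages):
            cols = off + tl.arange(0, BLOCK)
            acc += tl.load(x_ptr + row * n_cols + cols, mask=cols < n_cols, other=0.0)
        tl.store(out_ptr + row, tl.sum(acc, axis=0))

    x = torch.randn(4, 1024, device="cuda")
    out = torch.empty(4, device="cuda")
    row_sum_kernel[(4,)](x, out, x.shape[1], BLOCK=128, ENABLE_PIPELINING=False)

In the commit, the non-causal loop passes True (pipelining stays on) and the causal loop passes False, which is why the stateful workaround in the pipeliner pass below can be deleted.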

third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp

Lines changed: 0 additions & 6 deletions
@@ -74,12 +74,6 @@ FourStagePipeliner::FourStagePipeliner(scf::ForOp _forOp, int _numStages,
 }

 bool FourStagePipeliner::checkPrecondition(scf::ForOp forOp, int numStages) {
-  // Skip the second loop (causual loop)
-  static bool isFirst = true;
-  if (!isFirst)
-    return false;
-  isFirst = false;
-
   unsigned dotCount{};
   unsigned reduceCount{};
