Skip to content

Commit 6955edf

Browse files
authored
align num_stages=3 between PoC and main (#2338)
Although `num_stages` in python file is always 3. We did not accept `num_stages` as a parameter of `add_prefetch_block` pass. So PoC branch uses `num_stages=3`, main branch uses `num_stages=4`.
1 parent 18932d1 commit 6955edf

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def forward(q, k, v, causal, sm_scale):
159159
o = torch.empty_like(q, dtype=torch.float32)
160160
BLOCK_M = 128
161161
BLOCK_N = 64 if Lk <= 64 else 32
162-
num_stages = 4 if Lk <= 64 else 3
162+
num_stages = 3
163163
num_warps = 8 if Lq == 64 else 16
164164
causal = False
165165
stage = 3 if causal else 1

0 commit comments

Comments
 (0)