diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
index dc073d0e5c..606b67af52 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -171,7 +171,7 @@ def forward(q, k, v, causal, sm_scale):
     assert Lk in {16, 32, 64, 128}
     o = torch.empty_like(q, dtype=torch.float32)
     BLOCK_M = 128
-    BLOCK_N = 64 if Lk <= 64 else 32
+    BLOCK_N = 64
     num_stages = 3
     num_warps = 8 if Lq == 64 else 16
     stage = 3 if causal else 1
@@ -205,7 +205,8 @@ def forward(q, k, v, causal, sm_scale):
         BLOCK_DMODEL=Lk, #
         STAGE=stage, #
         num_warps=num_warps, #
-        num_stages=num_stages #
+        num_stages=num_stages, #
+        grf_mode='large', #
     )
     return o
