diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 9409452aa5..a31290d850 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -78,10 +78,11 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, # start_m = tl.program_id(2) off_z = tl.program_id(0) off_h = tl.program_id(1) + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh if N_CTX <= 512: start_m = tl.program_id(0) off_z = tl.program_id(2) - qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + qvk_offset = off_z.to(tl.int64) * stride_qh # block pointers Q_block_ptr = tl.make_block_ptr( @@ -181,7 +182,7 @@ def forward(q, k, v, causal, sm_scale): grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M'])) n_ctx = q.shape[2] if n_ctx <= 512: - grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0]) + grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), 1, q.shape[0] * q.shape[1]) M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':