Commit 0d07174

fix

1 parent d754a31

File tree: 1 file changed (+6, -7 lines)

python/perf-kernels/flash-attention.py

Lines changed: 6 additions & 7 deletions
@@ -443,16 +443,16 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
              ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr,
              INT8: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr):
 
-    if PERSISTENT:  # if persistent, kernel loops over multiple tiles
+    if PERSISTENT:  # if persistent, kernel loops over multiple tiles
         NUM_WG = NUM_CU * GRID_CU_MULTIP  # number of workgroups launched
         num_tiles_per_head = tl.cdiv(MAX_SEQLENS_Q, BLOCK_M)  # the number of work units (tiles) of a single head
         num_tiles_per_sample = num_tiles_per_head * HQ  # times the number of heads
         num_tiles_total = num_tiles_per_sample * B  # times the number of samples
-        if PERSISTENT_DYNAMIC:
+        if PERSISTENT_DYNAMIC:
             tile_id = atomic_counter.atomic_add(1)  # returns the value BEFORE the atomic operation
         else:
             tile_id = tl.program_id(0)
-    else:  # standard, kernel processes only one tile
+    else:  # standard, kernel processes only one tile
         tile_id = 0
         num_tiles_total = 1
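A note on the arithmetic in this hunk: it flattens every (sample, head, query-block) triple into a single tile index so that persistent workgroups can share one work queue. Below is a minimal host-side sketch of the same bookkeeping in plain Python; the function name, the example sizes, and the ceil-division stand-in for tl.cdiv are ours, not the kernel's.

# Host-side sketch of the tile bookkeeping in attn_fwd (illustrative only).
# -(-a // b) is integer ceiling division, standing in for tl.cdiv(a, b).

def tile_counts(max_seqlens_q, block_m, hq, b, num_cu, grid_cu_multip):
    num_wg = num_cu * grid_cu_multip                  # workgroups launched
    tiles_per_head = -(-max_seqlens_q // block_m)     # query blocks per head
    tiles_per_sample = tiles_per_head * hq            # times the number of heads
    tiles_total = tiles_per_sample * b                # times the number of samples
    return num_wg, tiles_total

# Example with assumed sizes: 1019 query tokens, BLOCK_M = 128, 4 heads,
# 4 samples, 80 CUs, grid multiplier 2.
num_wg, tiles_total = tile_counts(1019, 128, 4, 4, 80, 2)
assert (num_wg, tiles_total) == (160, 128)    # ceil(1019/128) = 8; 8 * 4 * 4 = 128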

@@ -466,7 +466,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
     start_m = tl.program_id(0)
     off_h_q = tl.program_id(1)
     off_z = tl.program_id(2)
-
+
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
     offs_d = tl.arange(0, BLOCK_DMODEL)
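For orientation, the context lines here build the per-tile index vectors: offs_m selects the BLOCK_M query rows this program owns, while offs_n and offs_d span a key/value block and the head dimension. A NumPy sketch with assumed block sizes (the concrete values are ours, not the kernel's):

import numpy as np

# Sketch of the kernel's index vectors; block sizes below are assumed examples.
BLOCK_M, BLOCK_N, BLOCK_DMODEL = 128, 64, 64
start_m = 2                                       # plays the role of tl.program_id(0)
offs_m = start_m * BLOCK_M + np.arange(BLOCK_M)   # query rows 256..383 of this tile
offs_n = np.arange(BLOCK_N)                       # key/value columns within a block
offs_d = np.arange(BLOCK_DMODEL)                  # lanes of the head dimension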
@@ -734,7 +734,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
             else:
                 tile_id += NUM_WG
         else:
-            tile_id = num_tiles_total  # break after single tile
+            tile_id = num_tiles_total  # break after single tile
 
 
 @triton.jit
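The tile_id = num_tiles_total assignment is the loop-exit idiom: both the persistent and standard paths fall through one "while tile_id < num_tiles_total" loop, and the non-persistent path forces the condition false after a single iteration. A simplified Python rendering of that control flow (not the kernel itself; fetch_next_tile is a hypothetical stand-in for the kernel's atomic counter):

# Simplified control flow of the persistent loop (illustrative, not the kernel).

def persistent_loop(tile_id, num_tiles_total, num_wg,
                    persistent, persistent_dynamic,
                    fetch_next_tile, process_tile):
    while tile_id < num_tiles_total:
        process_tile(tile_id)                  # one (sample, head, query-block) tile
        if persistent:
            if persistent_dynamic:
                tile_id = fetch_next_tile()    # dynamic: claim the next unprocessed tile
            else:
                tile_id += num_wg              # static: stride by the workgroup count
        else:
            tile_id = num_tiles_total          # break after single tile

# Static persistent schedule: workgroup 3 of 160 handles tiles 3, 163, 323, ...
persistent_loop(3, 480, 160, True, False, None, print)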
@@ -2017,8 +2017,7 @@ def main():
     assert args.dtype in arg_to_torch_dtype, \
         "Only fp16, bf16 and f32 types currently supported."
 
-    test_op_fwd_int8(4, 4, 65, 1019, 65, True, True, 'bhsd')
-    # run_benchmark(custom_config, args)
+    run_benchmark(custom_config, args)
 
 
 if __name__ == '__main__':
