@@ -443,16 +443,16 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
443443 ENABLE_DROPOUT : tl .constexpr , RETURN_ENCODED_SOFTMAX : tl .constexpr , USE_ALIBI : tl .constexpr ,
444444 INT8 : tl .constexpr , USE_P_SCALE : tl .constexpr , INT8_KV : tl .constexpr ):
445445
446- if PERSISTENT : # if persistent, kernel loops over multiple tiles
446+ if PERSISTENT : # if persistent, kernel loops over multiple tiles
447447 NUM_WG = NUM_CU * GRID_CU_MULTIP # number of workgroups launched
448448 num_tiles_per_head = tl .cdiv (MAX_SEQLENS_Q , BLOCK_M ) # the number of work units (tiles) of a single head
449449 num_tiles_per_sample = num_tiles_per_head * HQ # times the number of heads
450450 num_tiles_total = num_tiles_per_sample * B # times the number of samples
451- if PERSISTENT_DYNAMIC :
451+ if PERSISTENT_DYNAMIC :
452452 tile_id = atomic_counter .atomic_add (1 ) # returns the value BEFORE the atomic operation
453453 else :
454454 tile_id = tl .program_id (0 )
455- else : # standard, kernel processes only one tile
455+ else : # standard, kernel processes only one tile
456456 tile_id = 0
457457 num_tiles_total = 1
458458
@@ -466,7 +466,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
466466 start_m = tl .program_id (0 )
467467 off_h_q = tl .program_id (1 )
468468 off_z = tl .program_id (2 )
469-
469+
470470 offs_m = start_m * BLOCK_M + tl .arange (0 , BLOCK_M )
471471 offs_n = tl .arange (0 , BLOCK_N )
472472 offs_d = tl .arange (0 , BLOCK_DMODEL )
@@ -734,7 +734,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
734734 else :
735735 tile_id += NUM_WG
736736 else :
737- tile_id = num_tiles_total # break after single tile
737+ tile_id = num_tiles_total # break after single tile
738738
739739
740740@triton .jit
@@ -2017,8 +2017,7 @@ def main():
20172017 assert args .dtype in arg_to_torch_dtype , \
20182018 "Only fp16, bf16 and f32 types currently supported."
20192019
2020- test_op_fwd_int8 (4 , 4 , 65 , 1019 , 65 , True , True , 'bhsd' )
2021- # run_benchmark(custom_config, args)
2020+ run_benchmark (custom_config , args )
20222021
20232022
20242023if __name__ == '__main__' :
0 commit comments