@@ -50,8 +50,6 @@ def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
     ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
-    if "tiles_per_update" in args:
-        ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]"
     if "c_ptr" in args:
         bytes_per_elem = args["c_ptr"].element_size()
     else:
@@ -376,8 +374,7 @@ def matmul_tma_persistent(a, b):
 
 
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr,  #
-                                        a_ptr, b_ptr, c_ptr,  #
+def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
                                         M, N, K,  #
                                         BLOCK_SIZE_M: tl.constexpr,  #
                                         BLOCK_SIZE_N: tl.constexpr,  #
@@ -417,7 +414,6 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
 
     tile_id = start_pid - NUM_SMS
     ki = -1
-    ni = -1
 
     pid_m = 0
     pid_n = 0
@@ -427,36 +423,10 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    # Create an opaque value to prevent the descriptor creation from being
-    # hoisted out of the loop
-    zero = tl.inline_asm_elementwise("mov.b32 $0, 0;", "=r", [], dtype=tl.int32, is_pure=True, pack=1)
 
     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
         if ki == 0:
-            ni += 1
-
-            # Simulate a grouped gemm
-            if ni == tiles_per_update:
-                a_desc = tl._experimental_make_tensor_descriptor(
-                    a_ptr + zero,
-                    shape=[M, K],
-                    strides=[K, 1],
-                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
-                )
-                b_desc = tl._experimental_make_tensor_descriptor(
-                    b_ptr + zero,
-                    shape=[N, K],
-                    strides=[K, 1],
-                    block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
-                )
-                c_desc = tl._experimental_make_tensor_descriptor(
-                    c_ptr + zero,
-                    shape=[M, N],
-                    strides=[N, 1],
-                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-                )
-                ni = 0
 
             tile_id += NUM_SMS
             group_id = tile_id // num_pid_in_group
@@ -482,8 +452,7 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, #
             accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
 
-def matmul_descriptor_persistent(a, b, tiles_per_update):
-    # Autotuner does not work with TMA. Use manual config.
+def matmul_descriptor_persistent(a, b):
     configs = {
         torch.float8_e4m3fn: {
             "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
@@ -513,7 +482,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
 
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
     matmul_kernel_descriptor_persistent[grid](
-        tiles_per_update,  #
         a, b, c,  #
         M, N, K,  #
         BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],  #
@@ -570,7 +538,7 @@ def bench_fn(reps, warmup_reps, fn, *args):
         fn(*args)
 
 
-def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000):
+def bench(K, dtype, reps=1000, warmup_reps=10000):
     M = 8192
     N = 8192
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
@@ -586,10 +554,10 @@ def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000):
     bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
     if supports_tma():
         bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
-        bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b, tiles_per_update)
+        bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)
 
 
-def validate(M, N, K, dtype, tiles_per_update):
+def validate(M, N, K, dtype):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
@@ -599,7 +567,7 @@ def validate(M, N, K, dtype, tiles_per_update):
     naive_result = matmul(a, b.T)
     persistent_result = matmul_persistent(a, b.T)
     tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    descriptor_persistent_result = matmul_descriptor_persistent(a, b, tiles_per_update) if supports_tma() else None
+    descriptor_persistent_result = matmul_descriptor_persistent(a, b) if supports_tma() else None
 
     if torch_result is not None:
         naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -624,7 +592,7 @@ def validate(M, N, K, dtype, tiles_per_update):
     if tma_persistent_result is not None:
         print(f"TMA persistent: {naive_vs_tma_persistent} ", end="")
     if descriptor_persistent_result is not None:
-        print(f"Device TMA persistent: {naive_vs_descriptor_persistent} ", end="")
+        print(f"Tensor descriptor persistent: {naive_vs_descriptor_persistent} ", end="")
     print()
 
 
@@ -644,13 +612,6 @@ def show_profile(precision, profile_name):
     parser.add_argument("-K", type=int, required=False, default=512)
     parser.add_argument("--K_range", type=int, nargs=2)
     parser.add_argument("--K_step", type=int, default=512)
-    parser.add_argument(
-        "--tiles_per_update",
-        type=int,
-        default=1,
-        help=
-        "Number of output tiles calculated for each update of the tma descriptor in matmul_descriptor_persistent_kernel",
-    )
     parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
     args = parser.parse_args()
 
@@ -666,11 +627,11 @@ def show_profile(precision, profile_name):
 
     torch.manual_seed(0)
 
-    validate(32, 32, 32, dtype, args.tiles_per_update)
-    validate(8192, 8192, 512, dtype, args.tiles_per_update)
+    validate(32, 32, 32, dtype)
+    validate(8192, 8192, 512, dtype)
 
     proton.start("matmul", hook="triton")
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-        bench(K, dtype, args.tiles_per_update)
+        bench(K, dtype)
     proton.finalize()
     show_profile(args.prec, "matmul")
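
For reference, a minimal usage sketch (not part of the diff) of the descriptor-persistent path once this change lands: matmul_descriptor_persistent now takes just the two operands. The shapes below are illustrative, and the sketch assumes it runs in the context of the modified tutorial, where matmul_descriptor_persistent and supports_tma are in scope.

    import torch

    # Illustrative call to the simplified entry point; the
    # tiles_per_update argument no longer exists.
    if supports_tma():
        M, N, K = 8192, 8192, 512  # same shapes validate() exercises above
        a = torch.randn((M, K), device="cuda", dtype=torch.float16)
        # The tutorial passes b pre-transposed and contiguous, i.e. shape (N, K).
        b = torch.randn((K, N), device="cuda", dtype=torch.float16).T.contiguous()
        c = matmul_descriptor_persistent(a, b)  # previously (a, b, tiles_per_update)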