@@ -360,9 +360,9 @@ def matmul_tma_persistent(a, b):
 
 
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
+def matmul_kernel_device_tma_persistent(workspace_ptr,  #
+                                        tiles_per_update: tl.constexpr,  #
                                         a_ptr, b_ptr, c_ptr,  #
-                                        ready_flag,  #
                                         M, N, K,  #
                                         BLOCK_SIZE_M: tl.constexpr,  #
                                         BLOCK_SIZE_N: tl.constexpr,  #
@@ -377,31 +377,32 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n
 
-    if start_pid == 0:
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
-                                                             element_ty=a_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
-                                                             load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
-                                                             element_ty=b_ptr.dtype.element_ty)
-        tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
-                                                             load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
-                                                             element_ty=c_ptr.dtype.element_ty)
-        tl.atomic_xchg(ready_flag, 1, sem="release")
-    else:
-        flag = tl.full([], 0, tl.int32)
-        while flag != 1:
-            flag = tl.atomic_add(ready_flag, 0, sem="acquire")
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+    TMA_SIZE: tl.constexpr = 128
+    workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE
+    a_desc_ptr = workspace_base
+    b_desc_ptr = workspace_base + TMA_SIZE
+    c_desc_ptr = workspace_base + 2 * TMA_SIZE
+
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
+                                                         element_ty=a_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                         load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
+                                                         element_ty=b_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
+                                                         element_ty=c_ptr.dtype.element_ty)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
 
     tiles_per_SM = num_tiles // NUM_SMS
     if start_pid < num_tiles % NUM_SMS:
         tiles_per_SM += 1
 
     tile_id = start_pid - NUM_SMS
     ki = -1
+    ni = -1
 
     pid_m = 0
     pid_n = 0
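Note on the new scheme: instead of three shared descriptor buffers guarded by a `ready_flag` handshake, every program now builds its descriptors in a private slice of one flat byte buffer, so no cross-program synchronization is needed; the `fenceproxy_acquire` calls only order each program's own tensormap writes against its later TMA accesses. A minimal host-side sketch of the layout arithmetic follows (the helper name and example values are illustrative, not part of the patch):

# Illustrative only: mirrors the kernel's pointer arithmetic above.
TMA_SIZE = 128  # bytes per CUtensorMap descriptor, as assumed by the patch

def descriptor_offsets(start_pid: int):
    # Each program owns a contiguous 3 * TMA_SIZE slice: one slot each
    # for the A, B and C descriptors.
    base = start_pid * 3 * TMA_SIZE
    return base, base + TMA_SIZE, base + 2 * TMA_SIZE

# Program 5 writes its descriptors at byte offsets 1920, 2048 and 2176.
assert descriptor_offsets(5) == (1920, 2048, 2176)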
@@ -415,6 +416,27 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
         if ki == 0:
+            ni += 1
+
+            # Simulate a grouped gemm
+            if ni == tiles_per_update:
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
+                                                                     load_size=[BLOCK_SIZE_M,
+                                                                                BLOCK_SIZE_K], global_size=[M, K],
+                                                                     element_ty=a_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
+                                                                     load_size=[BLOCK_SIZE_N,
+                                                                                BLOCK_SIZE_K], global_size=[N, K],
+                                                                     element_ty=b_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
+                                                                     load_size=[BLOCK_SIZE_M,
+                                                                                BLOCK_SIZE_N], global_size=[M, N],
+                                                                     element_ty=c_ptr.dtype.element_ty)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+                ni = 0
+
             tile_id += NUM_SMS
             group_id = tile_id // num_pid_in_group
             first_pid_m = group_id * GROUP_SIZE_M
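The `tiles_per_update` constexpr controls how often the loop above rebuilds the three descriptors, imitating a grouped GEMM where a new problem (and hence a new set of tensormaps) begins every N output tiles. A kernel-free sketch of that cadence, with an illustrative helper name and values:

# Illustrative only: `ni` counts output tiles since the last rebuild,
# and the descriptors are recreated every `tiles_per_update` tiles.
def rebuild_schedule(num_tiles: int, tiles_per_update: int):
    rebuilds = []
    ni = -1
    for tile in range(num_tiles):
        ni += 1
        if ni == tiles_per_update:
            rebuilds.append(tile)  # descriptors recreated before this tile
            ni = 0
    return rebuilds

# With 8 tiles and an update every 2 tiles, rebuilds land on tiles 2, 4, 6.
assert rebuild_schedule(8, 2) == [2, 4, 6]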
@@ -435,10 +457,11 @@ def matmul_kernel_device_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
             c = accumulator.to(dtype)
 
             tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
+
             accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
 
-def matmul_device_tma_persistent(a, b):
+def matmul_device_tma_persistent(a, b, tiles_per_update):
     # Autotuner does not work with TMA. Use manual config.
     configs = {
         torch.float8_e4m3fn: {
@@ -459,15 +482,15 @@ def matmul_device_tma_persistent(a, b):
     dtype = a.dtype
 
     c = torch.zeros((M, N), device=a.device, dtype=dtype)
-    a_desc, b_desc, c_desc = [torch.empty(128, dtype=torch.uint8, device="cuda") for _ in range(3)]
-    ready_flag = torch.zeros((), dtype=torch.int32, device="cuda")
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    tma_size = 128
+    workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda")
 
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
     matmul_kernel_device_tma_persistent[grid](
-        a_desc, b_desc, c_desc,  #
+        workspace,  #
+        tiles_per_update,  #
         a, b, c,  #
-        ready_flag,  #
         M, N, K,  #
         BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],  #
         BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"],  #
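On the host side, the single `workspace` buffer replaces the three shared descriptor tensors and the flag. One 3-descriptor slice per SM always suffices because the launch grid is capped at NUM_SMS, so at most NUM_SMS programs ever run. A self-contained sketch of both calculations (helper names are illustrative):

import torch

def alloc_tma_workspace(num_sms: int, tma_size: int = 128, descs: int = 3):
    # One flat uint8 buffer; the kernel carves out a slice of
    # descs * tma_size bytes per program, matching the allocation above.
    return torch.empty(num_sms * descs * tma_size, dtype=torch.uint8, device="cuda")

def cdiv(x, y):
    return (x + y - 1) // y

def launch_grid(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, NUM_SMS):
    # Persistent launch: never more programs than SMs.
    return (min(NUM_SMS, cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)), )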
@@ -507,7 +530,7 @@ def torch_matmul(a, b):
     return c
 
 
-def bench(K, dtype, reps=10):
+def bench(K, dtype, tiles_per_update, reps=10):
     M = 8192
     N = 8192
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
@@ -535,14 +558,18 @@ def bench(K, dtype, reps=10):
         for _ in range(reps):
             matmul_tma_persistent(a, b)
             time.sleep(0.01)
-        for _ in range(reps):
-            matmul_device_tma_persistent(a, b)
-            time.sleep(0.01)
+        flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops"
+        with proton.scope(
+                f"matmul_kernel_device_tma_persistent M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}",
+            {"bytes": a.element_size() * (M * K + N * K), flops_str: 2. * M * N * K}):
+            for _ in range(reps):
+                matmul_device_tma_persistent(a, b, tiles_per_update)
+                time.sleep(0.01)
 
     proton.deactivate(0)
 
 
-def validate(M, N, K, dtype):
+def validate(M, N, K, dtype, tiles_per_update):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
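The explicit `proton.scope` is what makes the cadence visible in profiles: the zero-padded `tiles_per_update` value is baked into the scope name, so runs with different update frequencies appear as distinct rows. A small illustration of the formatting (values are examples only):

# Illustrative: the :02 format spec zero-pads to two digits.
M, N, K, tiles_per_update = 8192, 8192, 512, 4
name = (f"matmul_kernel_device_tma_persistent "
        f"M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}")
assert name.endswith("tiles_per_update=04")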
@@ -552,7 +579,7 @@ def validate(M, N, K, dtype):
     naive_result = matmul(a, b.T)
     persistent_result = matmul_persistent(a, b.T)
     tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    device_tma_persistent_result = matmul_device_tma_persistent(a, b) if supports_tma() else None
+    device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if supports_tma() else None
 
     if torch_result is not None:
         naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -586,6 +613,13 @@ def validate(M, N, K, dtype):
     parser.add_argument("-K", type=int, required=False, default=512)
     parser.add_argument("--K_range", type=int, nargs=2)
     parser.add_argument("--K_step", type=int, default=512)
+    parser.add_argument(
+        "--tiles_per_update",
+        type=int,
+        default=1,
+        help=
+        "Number of output tiles calculated for each update of the TMA descriptors in matmul_kernel_device_tma_persistent",
+    )
     parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
     args = parser.parse_args()
 
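With the flag wired through `validate` and `bench` below, the K sweep can be rerun at several update frequencies; note that the default of 1 rebuilds the descriptors before nearly every output tile, the worst case for update overhead. Assuming the tutorial script is invoked directly (the file name here is illustrative), an example run might look like:

python 09-persistent-matmul.py --prec fp8 --K_range 512 8192 --K_step 512 --tiles_per_update 4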
@@ -601,10 +635,10 @@ def validate(M, N, K, dtype):
 
     torch.manual_seed(0)
 
-    validate(32, 32, 32, dtype)
-    validate(8192, 8192, 512, dtype)
+    validate(32, 32, 32, dtype, args.tiles_per_update)
+    validate(8192, 8192, 512, dtype, args.tiles_per_update)
 
     proton.start("matmul", hook="triton")
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
-        bench(K, dtype)
+        bench(K, dtype, args.tiles_per_update)
     proton.finalize()