 """
 
 import argparse
+import itertools
 
 import torch
 import triton
@@ -46,10 +47,15 @@ def supports_tma():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
 
 
+def supports_ws():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 10
+
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
-    M, N, K = args["M"], args["N"], args["K"]
-    ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
+    M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
+    ws_str = "_ws" if WS else ""
+    ret["name"] = f"{kernel.name}{ws_str} [M={M}, N={N}, K={K}]"
     if "c_ptr" in args:
         bytes_per_elem = args["c_ptr"].element_size()
     else:
@@ -61,6 +67,7 @@ def _matmul_launch_metadata(grid, kernel, args):
 
 HAS_TMA_DESC = supports_tma() and hasattr(tl, "nv_tma_desc_type")
 HAS_TENSOR_DESC = supports_tma() and hasattr(tl, "make_tensor_descriptor")
+HAS_WARP_SPECIALIZE = supports_ws() and HAS_TENSOR_DESC
 
 
 # TmaAutoTuneHelper used in htyu's PR #5622
@@ -197,17 +204,18 @@ def matmul(a, b):
 
 @triton.autotune(
     configs=matmul_get_configs(),
-    key=["M", "N", "K"],
+    key=["M", "N", "K", "WARP_SPECIALIZE"],
 )
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_tma_ws(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
-                         M, N, K,  #
-                         BLOCK_SIZE_M: tl.constexpr,  #
-                         BLOCK_SIZE_N: tl.constexpr,  #
-                         BLOCK_SIZE_K: tl.constexpr,  #
-                         GROUP_SIZE_M: tl.constexpr,  #
-                         FP8_OUTPUT: tl.constexpr,  #
-                         ):
+def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
+                      M, N, K,  #
+                      BLOCK_SIZE_M: tl.constexpr,  #
+                      BLOCK_SIZE_N: tl.constexpr,  #
+                      BLOCK_SIZE_K: tl.constexpr,  #
+                      GROUP_SIZE_M: tl.constexpr,  #
+                      FP8_OUTPUT: tl.constexpr,  #
+                      WARP_SPECIALIZE: tl.constexpr,  #
+                      ):
     dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
 
     pid = tl.program_id(axis=0)
@@ -227,7 +235,7 @@ def matmul_kernel_tma_ws(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
-    for k in tl.range(k_tiles, warp_specialize=True, num_stages=3):
+    for k in tl.range(k_tiles, warp_specialize=WARP_SPECIALIZE):
         offs_k = k * BLOCK_SIZE_K
         a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
         b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype)
@@ -240,7 +248,7 @@ def matmul_kernel_tma_ws(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     tl._experimental_descriptor_store(c_desc_ptr, c, [offs_cm, offs_cn])
 
 
-def matmul_tma_ws(a, b):
+def matmul_tma(a, b, warp_specialize: bool):
     # Check constraints.
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     assert a.dtype == b.dtype, "Incompatible dtypes"
@@ -296,10 +304,11 @@ def grid(META):
     desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
     desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
 
-    matmul_kernel_tma_ws[grid](
+    matmul_kernel_tma[grid](
         desc_a, desc_b, desc_c,  #
         M, N, K,  #
         FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
+        WARP_SPECIALIZE=warp_specialize,  #
     )
     return c
 
@@ -402,19 +411,23 @@ def matmul_persistent(a, b):
 
 def matmul_tma_persistent_get_configs():
     return [
-        triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K": BK, "GROUP_SIZE_M": 8, "EPILOGUE_SUBTILE": SUBTILE}, num_stages=s, num_warps=w) \
-        for BM in [128] \
-        for BN in [128, 256] \
-        for BK in [64, 128] \
-        for s in ([2, 3, 4]) \
-        for w in [4, 8] \
-        for SUBTILE in [True, False] \
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K": BK, "GROUP_SIZE_M": 8, "EPILOGUE_SUBTILE":
+                SUBTILE
+            }, num_stages=s, num_warps=w)  #
+        for BM in [128]  #
+        for BN in [128, 256]  #
+        for BK in [64, 128]  #
+        for s in ([2, 3, 4])  #
+        for w in [4, 8]  #
+        for SUBTILE in [True, False]  #
     ]
 
 
 @triton.autotune(
     configs=matmul_tma_persistent_get_configs(),
-    key=["M", "N", "K"],
+    key=["M", "N", "K", "WARP_SPECIALIZE"],
 )
 @triton.jit(launch_metadata=_matmul_launch_metadata)
 def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
@@ -425,7 +438,9 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
                                  GROUP_SIZE_M: tl.constexpr,  #
                                  FP8_OUTPUT: tl.constexpr,  #
                                  EPILOGUE_SUBTILE: tl.constexpr,  #
-                                 NUM_SMS: tl.constexpr):  #
+                                 NUM_SMS: tl.constexpr,  #
+                                 WARP_SPECIALIZE: tl.constexpr,  #
+                                 ):
     dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
     start_pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -439,7 +454,7 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     # Enable warp specialization to leverage async warp scheduling in the GPU.
     # FIXME: This only works on Blackwell right now. On older GPUs, this will
     # use software pipelining.
-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True):
+    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
         offs_am = pid_m * BLOCK_SIZE_M
         offs_bn = pid_n * BLOCK_SIZE_N
@@ -473,7 +488,7 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
             tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am_c, offs_bn_c])
 
 
-def matmul_tma_persistent(a, b):
+def matmul_tma_persistent(a, b, warp_specialize: bool):
     # Check constraints.
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     assert a.dtype == b.dtype, "Incompatible dtypes"
@@ -542,13 +557,14 @@ def grid(META):
         M, N, K,  #
         FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
         NUM_SMS=NUM_SMS,  #
+        WARP_SPECIALIZE=warp_specialize,  #
     )
     return c
 
 
 @triton.autotune(
     configs=matmul_tma_persistent_get_configs(),
-    key=["M", "N", "K"],
+    key=["M", "N", "K", "WARP_SPECIALIZE"],
 )
 @triton.jit(launch_metadata=_matmul_launch_metadata)
 def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
@@ -558,7 +574,9 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
                                         BLOCK_SIZE_K: tl.constexpr,  #
                                         GROUP_SIZE_M: tl.constexpr,  #
                                         EPILOGUE_SUBTILE: tl.constexpr,  #
-                                        NUM_SMS: tl.constexpr):  #
+                                        NUM_SMS: tl.constexpr,  #
+                                        WARP_SPECIALIZE: tl.constexpr,  #
+                                        ):
     # Matmul using TMA and device-side descriptor creation
     dtype = c_ptr.dtype.element_ty
     start_pid = tl.program_id(axis=0)
@@ -591,7 +609,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
+    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
         offs_am = pid_m * BLOCK_SIZE_M
         offs_bn = pid_n * BLOCK_SIZE_N
@@ -621,7 +639,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
             c_desc.store([offs_cm, offs_cn], c)
 
 
-def matmul_descriptor_persistent(a, b):
+def matmul_descriptor_persistent(a, b, warp_specialize: bool):
     # Check constraints.
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     assert a.dtype == b.dtype, "Incompatible dtypes"
@@ -644,6 +662,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
         a, b, c,  #
         M, N, K,  #
         NUM_SMS=NUM_SMS,  #
+        WARP_SPECIALIZE=warp_specialize,  #
     )
     return c
 
@@ -683,12 +702,14 @@ def proton_context():
         proton.deactivate(0)
 
 
-def bench_fn(reps, warmup_reps, fn, *args):
+def bench_fn(label, reps, warmup_reps, fn, *args):
+    print(f"Benchmarking {label}: ...", end="")
     for _ in range(warmup_reps):
         fn(*args)
     with proton_context():
         for _ in range(reps):
             fn(*args)
+    print(f"\rBenchmarking {label}: done")
 
 
 def bench(K, dtype, reps=10000, warmup_reps=10000):
@@ -700,60 +721,55 @@ def bench(K, dtype, reps=10000, warmup_reps=10000):
     b = b.T.contiguous()
 
     if cublas is not None:
-        bench_fn(reps, warmup_reps, cublas_matmul, a, b)
+        bench_fn("cublas", reps, 1, cublas_matmul, a, b)
     if dtype == torch.float16:
-        bench_fn(reps, warmup_reps, torch_matmul, a, b)
-    bench_fn(reps, warmup_reps, matmul, a, b.T)
-    bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
-    if HAS_TMA_DESC:
-        bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
-    if HAS_TENSOR_DESC:
-        bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)
-        bench_fn(reps, warmup_reps, matmul_tma_ws, a, b)
+        bench_fn("torch", reps, warmup_reps, torch_matmul, a, b)
+    bench_fn("naive", reps, warmup_reps, matmul, a, b.T)
+    bench_fn("persistent", reps, warmup_reps, matmul_persistent, a, b.T)
+    warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False]
+    for ws in warp_specialize:
+        ws_str = "_ws" if ws else ""
+        if HAS_TMA_DESC:
+            bench_fn(f"tma_persistent{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma_persistent(a, b, ws), a, b)
+            bench_fn(f"tma{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma(a, b, ws), a, b)
+        if HAS_TENSOR_DESC:
+            bench_fn(f"descriptor_persistent{ws_str}", reps, warmup_reps,
+                     lambda a, b: matmul_descriptor_persistent(a, b, ws), a, b)
+
+
+def run_test(expect, fn, a, b, label, enabled=True):
+    print(f" {label}: ...", end="")
+    if enabled:
+        actual = fn(a, b)
+        passed = torch.allclose(expect, actual.to(expect.dtype), atol=1.0)
+        icon = "✅" if passed else "❌"
+    else:
+        icon = "⭕"
+    print(f"\r {label}: {icon}")
 
 
 def validate(M, N, K, dtype):
+    print(f"{M=}, {N=}, {K=}, verification naive vs: ")
     a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
 
-    torch_result = torch_matmul(a, b) if dtype == torch.float16 else None
-    cublas_result = cublas_matmul(a, b) if cublas is not None else None
-    naive_result = matmul(a, b.T)
-    tma_ws_result = matmul_tma_ws(a, b) if HAS_TENSOR_DESC else None
-    persistent_result = matmul_persistent(a, b.T)
-    tma_persistent_result = matmul_tma_persistent(a, b) if HAS_TMA_DESC else None
-    descriptor_persistent_result = matmul_descriptor_persistent(a, b) if HAS_TENSOR_DESC else None
-
-    if tma_ws_result is not None:
-        naive_vs_tma_ws = "✅" if torch.allclose(naive_result.to(torch.float16), tma_ws_result.to(torch.float16),
-                                                atol=1.0) else "❌"
-    if torch_result is not None:
-        naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
-                                               atol=1.0) else "❌"
-    if cublas_result is not None:
-        naive_vs_cublas = "✅" if torch.allclose(naive_result.to(torch.float16), cublas_result.to(torch.float16),
-                                                atol=1.0) else "❌"
-    naive_vs_persistent = "✅" if torch.allclose(naive_result.to(torch.float16), persistent_result.to(torch.float16),
-                                                atol=1.0) else "❌"
-    if tma_persistent_result is not None:
-        naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16),
-                                                         tma_persistent_result.to(torch.float16), atol=1.0) else "❌"
-    if descriptor_persistent_result is not None:
-        naive_vs_descriptor_persistent = "✅" if torch.allclose(cublas_result.to(
-            torch.float16), descriptor_persistent_result.to(torch.float16), atol=1.0) else "❌"
-    print(f"M={M}, N={N}, K={K} verification naive vs: ", end="")
-    if tma_ws_result is not None:
-        print(f"tma: {naive_vs_tma_ws} ", end="")
-    if torch_result is not None:
-        print(f"torch: {naive_vs_torch} ", end="")
-    if cublas_result is not None:
-        print(f"cublas: {naive_vs_cublas} ", end="")
-    print(f"persistent: {naive_vs_persistent} ", end="")
-    if tma_persistent_result is not None:
-        print(f"TMA persistent: {naive_vs_tma_persistent} ", end="")
-    if descriptor_persistent_result is not None:
-        print(f"Tensor descriptor persistent: {naive_vs_descriptor_persistent} ", end="")
+    naive_result = matmul(a, b.T).to(torch.float16)
+    run_test(naive_result, torch_matmul, a, b, "Torch", enabled=dtype == torch.float16)
+    run_test(naive_result, cublas_matmul, a, b, "cuBLAS", enabled=cublas is not None)
+    run_test(naive_result, matmul_persistent, a, b.T, "Persistent")
+
+    kernels = [
+        (matmul_tma, "TMA", HAS_TMA_DESC),
+        (matmul_tma_persistent, "TMA Persistent", HAS_TMA_DESC),
+        (matmul_descriptor_persistent, "Tensor Descriptor Persistent", HAS_TENSOR_DESC),
+    ]
+    warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False]
+
+    for (kernel, label, enabled), warp_specialize in itertools.product(kernels, warp_specialize):
+        label = f"{label} (warp_specialize={warp_specialize})"
+        enabled = enabled and (not warp_specialize or HAS_TENSOR_DESC)
+        run_test(naive_result, lambda a, b: kernel(a, b, warp_specialize), a, b, label, enabled)
     print()
 
 