@@ -47,8 +47,12 @@ def supports_tma():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


+def is_hopper():
+    return torch.cuda.get_device_capability()[0] == 9
+
+
 def supports_ws():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 10
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


 def _matmul_launch_metadata(grid, kernel, args):
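For reference, compute capability major 9 is Hopper (sm_90) and major 10 is Blackwell, so this hunk widens warp-specialization support from Blackwell-only to Hopper and newer, and adds an `is_hopper()` check for the Hopper-specific fallbacks introduced below. A minimal, self-contained sketch of the same gating; the `is_cuda()` stub here is an assumption standing in for the tutorial's own helper defined earlier in the file:

import torch

def is_cuda():
    # Stand-in (assumed): a CUDA, non-ROCm GPU is the active backend.
    return torch.cuda.is_available() and torch.version.hip is None

def is_hopper():
    # Hopper parts (e.g. H100) report compute capability major == 9.
    return torch.cuda.get_device_capability()[0] == 9

def supports_ws():
    # Warp specialization is now gated on sm_90 and newer (Hopper, Blackwell).
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9

if is_cuda():
    print(f"hopper={is_hopper()} warp_specialization={supports_ws()}")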
@@ -465,21 +469,31 @@ def grid(META):
     return c


-@triton.autotune(
-    configs=matmul_tma_persistent_get_configs(),
-    key=["M", "N", "K", "WARP_SPECIALIZE"],
-)
+def prune_invalid_configs(configs, named_args, **kwargs):
+    FLATTEN = kwargs["FLATTEN"]
+    # Filter out configs that enable EPILOGUE_SUBTILE when FLATTEN is disabled
+    # (the Hopper warp-specialization path).
+    return [conf for conf in configs if not (conf.kwargs.get("EPILOGUE_SUBTILE", True) and FLATTEN is False)]
+
+
+@triton.autotune(configs=matmul_tma_persistent_get_configs(), key=["M", "N", "K", "WARP_SPECIALIZE", "FLATTEN"],
+                 prune_configs_by={'early_config_prune': prune_invalid_configs})
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
-                                        M, N, K,  #
-                                        BLOCK_SIZE_M: tl.constexpr,  #
-                                        BLOCK_SIZE_N: tl.constexpr,  #
-                                        BLOCK_SIZE_K: tl.constexpr,  #
-                                        GROUP_SIZE_M: tl.constexpr,  #
-                                        EPILOGUE_SUBTILE: tl.constexpr,  #
-                                        NUM_SMS: tl.constexpr,  #
-                                        WARP_SPECIALIZE: tl.constexpr,  #
-                                        ):
+def matmul_kernel_descriptor_persistent(
+        a_ptr,
+        b_ptr,
+        c_ptr,  #
+        M,
+        N,
+        K,  #
+        BLOCK_SIZE_M: tl.constexpr,  #
+        BLOCK_SIZE_N: tl.constexpr,  #
+        BLOCK_SIZE_K: tl.constexpr,  #
+        GROUP_SIZE_M: tl.constexpr,  #
+        EPILOGUE_SUBTILE: tl.constexpr,  #
+        NUM_SMS: tl.constexpr,  #
+        WARP_SPECIALIZE: tl.constexpr,  #
+        FLATTEN: tl.constexpr,
+):
     # Matmul using TMA and device-side descriptor creation
     dtype = c_ptr.dtype.element_ty
     start_pid = tl.program_id(axis=0)
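The new `prune_invalid_configs` hook relies on Triton's `prune_configs_by` mechanism: the callable registered under 'early_config_prune' receives the candidate configs, the kernel's named arguments, and the launch keyword arguments (which is how it reads `FLATTEN`) before any config is benchmarked. A hedged, standalone sketch of the same pattern on a toy kernel; the kernel, config values, and names below are illustrative and not part of the tutorial:

import torch
import triton
import triton.language as tl

def early_prune(configs, named_args, **kwargs):
    # Mirror prune_invalid_configs: drop EPILOGUE_SUBTILE variants whenever
    # the flattened loop is disabled at launch time.
    flatten = kwargs["FLATTEN"]
    return [c for c in configs if flatten or not c.kwargs.get("EPILOGUE_SUBTILE", False)]

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 1024, "EPILOGUE_SUBTILE": False}),
        triton.Config({"BLOCK": 1024, "EPILOGUE_SUBTILE": True}),
    ],
    key=["n_elements", "FLATTEN"],
    prune_configs_by={"early_config_prune": early_prune},
)
@triton.jit
def scale_kernel(x_ptr, n_elements, FLATTEN: tl.constexpr, BLOCK: tl.constexpr,
                 EPILOGUE_SUBTILE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)

x = torch.ones(4096, device="cuda")
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]), )
scale_kernel[grid](x, x.numel(), FLATTEN=False)  # the subtile config is pruned before tuning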
@@ -512,7 +526,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n

-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
+    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
         offs_am = pid_m * BLOCK_SIZE_M
         offs_bn = pid_n * BLOCK_SIZE_N
@@ -560,12 +574,19 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):

     triton.set_allocator(alloc_fn)

+    # Warp specialization on Hopper doesn't support the flattened loop, so disable it there.
+    flatten = False if (warp_specialize and is_hopper()) else True
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
     matmul_kernel_descriptor_persistent[grid](
-        a, b, c,  #
-        M, N, K,  #
+        a,
+        b,
+        c,  #
+        M,
+        N,
+        K,  #
         NUM_SMS=NUM_SMS,  #
         WARP_SPECIALIZE=warp_specialize,  #
+        FLATTEN=flatten,
     )
     return c

@@ -632,7 +653,8 @@ def bench(K, dtype, reps=10000, warmup_reps=10000):
     warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False]
     for ws in warp_specialize:
         ws_str = "_ws" if ws else ""
-        if HAS_HOST_TENSOR_DESC:
+        # Skip the host-descriptor kernels when warp specialization is requested on Hopper.
+        if HAS_HOST_TENSOR_DESC and not (is_hopper() and ws):
             bench_fn(f"tma_persistent{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma_persistent(a, b, ws), a, b)
             bench_fn(f"tma{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma(a, b, ws), a, b)
         if HAS_TENSOR_DESC:
@@ -671,7 +693,9 @@ def validate(M, N, K, dtype):

     for (kernel, label, enabled), warp_specialize in itertools.product(kernels, warp_specialize):
         label = f"{label} (warp_specialize={warp_specialize})"
-        enabled = enabled and (not warp_specialize or HAS_TENSOR_DESC)
+        # On Hopper, warp specialization is only supported by the on-device descriptor kernel.
+        skipped = is_hopper() and warp_specialize and kernel != matmul_descriptor_persistent
+        enabled = enabled and (not warp_specialize or HAS_TENSOR_DESC) and (not skipped)
         run_test(naive_result, lambda a, b: kernel(a, b, warp_specialize), a, b, label, enabled)
     print()

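Taken together, the bench and validate changes encode a single rule: on Hopper, warp-specialized runs are limited to the on-device descriptor kernel (`matmul_descriptor_persistent`), and the host-descriptor paths are skipped whenever warp specialization is requested. A hypothetical helper, not part of the tutorial, that states the predicate in one place:

def runs_with_warp_specialization(kernel_label):
    # Hypothetical summary of the gating above: every kernel may use warp
    # specialization except on Hopper, where only the on-device descriptor
    # persistent kernel supports it.
    if not is_hopper():
        return True
    return kernel_label == "descriptor_persistent"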