@@ -61,12 +61,8 @@ def _matmul_launch_metadata(grid, kernel, args):
     return ret
 
 
-HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
-
-if HAS_TMA_DESC:
-    print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", )
-else:
-    print("TMA benchmarks will be running without grid constant TMA descriptor.", )
+HAS_TMA_DESC = supports_tma() and hasattr(tl, "nv_tma_desc_type")
+HAS_TENSOR_DESC = supports_tma() and hasattr(tl, "make_tensor_descriptor")
 
 
 # TmaAutoTuneHelper used in htyu's PR #5622
@@ -86,49 +82,27 @@ def tma_desc_cpu_ptr(self):
     def __init__(self):
         self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor)
         self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor)
-        if HAS_TMA_DESC:
-            self.descriptors = {}
-        else:
-            self.cuda_descriptors = {}
+        self.descriptors = {}
 
     # Call this method outside of the lambda function for grid size
     def init_tma_descriptor(self, name):
-        if HAS_TMA_DESC:
-            self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)
-        else:
-            self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8)
+        self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)
 
     # Call this method inside the lambda function for grid size
     def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
-        if HAS_TMA_DESC:
-            desc_x = self.descriptors[name]
-            assert desc_x.data_ptr() % 64 == 0
-            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr())
-        else:
-            desc_x = self.cuda_descriptors[name]
-            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr())
-            desc_x.copy_(buf_x, non_blocking=True)
+        desc_x = self.descriptors[name]
+        assert desc_x.data_ptr() % 64 == 0
+        self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr())
 
     # Call this method inside the lambda function for grid size
     def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
-        if HAS_TMA_DESC:
-            desc_x = self.descriptors[name]
-            assert desc_x.data_ptr() % 64 == 0
-            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())
-        else:
-            desc_x = self.cuda_descriptors[name]
-            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
-            desc_x.copy_(buf_x, non_blocking=True)
+        desc_x = self.descriptors[name]
+        assert desc_x.data_ptr() % 64 == 0
+        self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())
 
     def get_tma_descriptor_kernel_param(self, name):
-        if HAS_TMA_DESC:
-            assert self.descriptors[name] is not None
-            return self.KernelParamWrapper(self.descriptors[name])
-        else:
-            assert self.cuda_descriptors[name] is not None
-            return self.cuda_descriptors[name]
+        assert self.descriptors[name] is not None
+        return self.KernelParamWrapper(self.descriptors[name])
 
 
 def matmul_get_configs():
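(The KernelParamWrapper returned by get_tma_descriptor_kernel_param is defined outside this hunk; the hunk header only shows its tma_desc_cpu_ptr method. For readers without the full file, a minimal sketch of such a wrapper, assuming the nested-class layout used in this tutorial; the exact body is an assumption:

    # Sketch only: duck-typed wrapper so the launcher can fetch the CPU pointer
    # of the host-side descriptor buffer and pass it by value as a
    # grid-constant TMA descriptor.
    class KernelParamWrapper:

        def __init__(self, desc):
            self.desc = desc  # torch.int8 CPU tensor holding the packed descriptor

        def tma_desc_cpu_ptr(self):
            return self.desc.data_ptr()
)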
@@ -228,7 +202,7 @@ def matmul(a, b):
     key=["M", "N", "K"],
 )
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_tma_ws_kernel(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
+def matmul_kernel_tma_ws(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
                          M, N, K,  #
                          BLOCK_SIZE_M: tl.constexpr,  #
                          BLOCK_SIZE_N: tl.constexpr,  #
@@ -321,7 +295,7 @@ def grid(META):
     desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
     desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
 
-    matmul_tma_ws_kernel[grid](
+    matmul_kernel_tma_ws[grid](
         desc_a, desc_b, desc_c,  #
         M, N, K,  #
     )
@@ -726,10 +700,11 @@ def bench(K, dtype, reps=1000, warmup_reps=10000):
     bench_fn(reps, warmup_reps, torch_matmul, a, b)
     bench_fn(reps, warmup_reps, matmul, a, b.T)
     bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
-    if supports_tma():
-        bench_fn(reps, warmup_reps, matmul_tma_ws, a, b)
+    if HAS_TMA_DESC:
         bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
+    if HAS_TENSOR_DESC:
         bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)
+        bench_fn(reps, warmup_reps, matmul_tma_ws, a, b)
 
 
 def validate(M, N, K, dtype):
@@ -740,10 +715,10 @@ def validate(M, N, K, dtype):
     torch_result = torch_matmul(a, b) if dtype == torch.float16 else None
     cublas_result = cublas_matmul(a, b) if cublas is not None else None
     naive_result = matmul(a, b.T)
-    tma_ws_result = matmul_tma_ws(a, b) if supports_tma() else None
+    tma_ws_result = matmul_tma_ws(a, b) if HAS_TENSOR_DESC else None
     persistent_result = matmul_persistent(a, b.T)
-    tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    descriptor_persistent_result = matmul_descriptor_persistent(a, b) if supports_tma() else None
+    tma_persistent_result = matmul_tma_persistent(a, b) if HAS_TMA_DESC else None
+    descriptor_persistent_result = matmul_descriptor_persistent(a, b) if HAS_TENSOR_DESC else None
 
     if tma_ws_result is not None:
         naive_vs_tma_ws = "✅" if torch.allclose(naive_result.to(torch.float16), tma_ws_result.to(torch.float16),