@@ -143,7 +143,6 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
143
143
cuMemcpyHtoD(B, hB, K*N*2);
144
144
145
145
// launch kernel
146
- cuStreamSynchronize(stream);
147
146
CUresult ret;
148
147
int algo_id = { algo_id } ;
149
148
if (algo_id == 0) {{
@@ -154,8 +153,6 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
154
153
if (ret != 0) fprintf(stderr, "kernel launch failed\\ n");
155
154
assert(ret == 0);
156
155
157
- cuStreamSynchronize(stream);
158
-
159
156
// read data
160
157
int32_t hC[M*N];
161
158
memset(hC, 0, M*N*4);
@@ -241,21 +238,20 @@ def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK):
241
238
242
239
def compile_aot_kernels (dir , kernel_path , dtype , BM , BN , BK , ha_hb_hints ):
243
240
# compile all desired configs
244
- for ha in ha_hb_hints :
245
- for hb in ha_hb_hints :
246
- sig = f"*fp32:16, *{ dtype } :16, *{ dtype } :16, i32, i32, i32, i32{ ha } , i32:1, i32{ hb } , i32:1, i32:16, i32:1, { BM } , { BN } , { BK } "
247
- name = f"matmul_{ dtype } "
248
- grid = f"M/{ BM } , N/{ BN } , 1"
249
- _compile_kernel (
250
- dir = dir ,
251
- signature = sig ,
252
- kernel_name = "kernel" ,
253
- out_name = name ,
254
- out_path = name ,
255
- num_warps = 1 ,
256
- grid = grid ,
257
- kernel_path = kernel_path ,
258
- )
241
+ for ha , hb in ha_hb_hints :
242
+ sig = f"*fp32:16, *{ dtype } :16, *{ dtype } :16, i32, i32, i32, i32{ ha } , i32:1, i32{ hb } , i32:1, i32:16, i32:1, { BM } , { BN } , { BK } "
243
+ name = f"matmul_{ dtype } "
244
+ grid = f"M/{ BM } , N/{ BN } , 1"
245
+ _compile_kernel (
246
+ dir = dir ,
247
+ signature = sig ,
248
+ kernel_name = "kernel" ,
249
+ out_name = name ,
250
+ out_path = name ,
251
+ num_warps = 1 ,
252
+ grid = grid ,
253
+ kernel_path = kernel_path ,
254
+ )
259
255
260
256
261
257
def link_aot_kernels (dir ):
@@ -317,7 +313,7 @@ def test_compile_link_matmul():
317
313
BM , BN , BK = 16 , 16 , 16
318
314
319
315
kernel_path = write_triton_kernels (tmp_dir , kernel_src , kernel_utils_src )
320
- compile_aot_kernels (tmp_dir , kernel_path , dtype , BM , BN , BK , ha_hb_hints = [" " , ":16" ])
316
+ compile_aot_kernels (tmp_dir , kernel_path , dtype , BM , BN , BK , ha_hb_hints = [( ":16 " , ":16" ) ])
321
317
link_aot_kernels (tmp_dir )
322
318
323
319
# compile test case
@@ -348,7 +344,7 @@ def test_launcher_has_no_available_kernel():
348
344
BM , BN , BK = 16 , 16 , 16
349
345
350
346
kernel_path = write_triton_kernels (tmp_dir , kernel_src , kernel_utils_src )
351
- compile_aot_kernels (tmp_dir , kernel_path , dtype , BM , BN , BK , ha_hb_hints = [":1" ])
347
+ compile_aot_kernels (tmp_dir , kernel_path , dtype , BM , BN , BK , ha_hb_hints = [( ":1" , ":1" ) ])
352
348
link_aot_kernels (tmp_dir )
353
349
354
350
# compile test case
@@ -385,14 +381,13 @@ def test_compile_link_autotune_matmul():
385
381
386
382
tile_sizes = [
387
383
[16 , 16 , 16 ],
388
- [32 , 32 , 16 ],
389
- [32 , 32 , 32 ],
390
384
[64 , 64 , 32 ],
391
385
]
392
386
393
387
for ts in tile_sizes :
394
388
BM , BN , BK = ts [0 ], ts [1 ], ts [2 ]
395
- compile_aot_kernels (tmp_dir , kernel_path , dtype , BM , BN , BK , ha_hb_hints = ["" , ":16" ])
389
+ compile_aot_kernels (tmp_dir , kernel_path , dtype , BM , BN , BK , ha_hb_hints = [(":16" , ":16" ), (":16" , "" ),
390
+ ("" , ":16" )])
396
391
397
392
link_aot_kernels (tmp_dir )
398
393
0 commit comments