Skip to content

Commit c6d9624

Browse files
authored
[TESTS] Reduce AOT test time (#7068)
Remove duplicate tests and unnecessary synchronization. On some platforms it reduces the time from 50s to 20s
1 parent 4ff4bd0 commit c6d9624

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

python/test/unit/tools/test_aot.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
143143
cuMemcpyHtoD(B, hB, K*N*2);
144144
145145
// launch kernel
146-
cuStreamSynchronize(stream);
147146
CUresult ret;
148147
int algo_id = {algo_id};
149148
if (algo_id == 0) {{
@@ -154,8 +153,6 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
154153
if (ret != 0) fprintf(stderr, "kernel launch failed\\n");
155154
assert(ret == 0);
156155
157-
cuStreamSynchronize(stream);
158-
159156
// read data
160157
int32_t hC[M*N];
161158
memset(hC, 0, M*N*4);
@@ -241,21 +238,20 @@ def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK):
241238

242239
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
243240
# compile all desired configs
244-
for ha in ha_hb_hints:
245-
for hb in ha_hb_hints:
246-
sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}"
247-
name = f"matmul_{dtype}"
248-
grid = f"M/{BM}, N/{BN}, 1"
249-
_compile_kernel(
250-
dir=dir,
251-
signature=sig,
252-
kernel_name="kernel",
253-
out_name=name,
254-
out_path=name,
255-
num_warps=1,
256-
grid=grid,
257-
kernel_path=kernel_path,
258-
)
241+
for ha, hb in ha_hb_hints:
242+
sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}"
243+
name = f"matmul_{dtype}"
244+
grid = f"M/{BM}, N/{BN}, 1"
245+
_compile_kernel(
246+
dir=dir,
247+
signature=sig,
248+
kernel_name="kernel",
249+
out_name=name,
250+
out_path=name,
251+
num_warps=1,
252+
grid=grid,
253+
kernel_path=kernel_path,
254+
)
259255

260256

261257
def link_aot_kernels(dir):
@@ -317,7 +313,7 @@ def test_compile_link_matmul():
317313
BM, BN, BK = 16, 16, 16
318314

319315
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
320-
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
316+
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[(":16", ":16")])
321317
link_aot_kernels(tmp_dir)
322318

323319
# compile test case
@@ -348,7 +344,7 @@ def test_launcher_has_no_available_kernel():
348344
BM, BN, BK = 16, 16, 16
349345

350346
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
351-
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[":1"])
347+
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[(":1", ":1")])
352348
link_aot_kernels(tmp_dir)
353349

354350
# compile test case
@@ -385,14 +381,13 @@ def test_compile_link_autotune_matmul():
385381

386382
tile_sizes = [
387383
[16, 16, 16],
388-
[32, 32, 16],
389-
[32, 32, 32],
390384
[64, 64, 32],
391385
]
392386

393387
for ts in tile_sizes:
394388
BM, BN, BK = ts[0], ts[1], ts[2]
395-
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"])
389+
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[(":16", ":16"), (":16", ""),
390+
("", ":16")])
396391

397392
link_aot_kernels(tmp_dir)
398393

0 commit comments

Comments
 (0)