@@ -68,6 +68,50 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
6868 return default_test_shapes [test_name ]
6969
7070
@require_e2e
def testGemmBench(tmp_path):
    """Check that benchmarking a GEMM writes result files for Wave and IREE.

    Compiles a 64x64x64 float16 GEMM with ``run_bench=True``, runs it once,
    and asserts that the Wave benchmark results file is created and contains
    the string ``"real_time"``. The same options object is then pointed at a
    second file and handed to ``generate_iree_ref``; the same two assertions
    are made for that file. Numerical results are not compared here.

    Args:
        tmp_path: directory under which both benchmark result files are
            created (used with the ``/`` operator and ``.exists()`` /
            ``.read_text()``).
    """
    shape = (64, 64, 64)
    m, n, k = shape
    wave_results = tmp_path / "wave_gemm_bench.txt"
    iree_results = tmp_path / "iree_gemm_bench.txt"

    # Static dims (False) with the F32_16x16x16_F16 MMA variant.
    gemm, hyperparams, dynamic_symbols = get_gemm_kernel(
        shape, False, MMAType.F32_16x16x16_F16, torch.float16
    )

    # The results file must not exist before the benchmark runs; otherwise
    # the existence check further down would prove nothing.
    assert not wave_results.exists()

    # NOTE(review): `enable_scheduling_barriers` is not bound inside this
    # function; presumably a module-level name -- confirm.
    options = set_default_run_config(
        WaveCompileOptions(
            subs=hyperparams,
            canonicalize=True,
            run_bench=True,
            schedule=SchedulingType.NONE,
            use_scheduling_barriers=enable_scheduling_barriers,
            dynamic_symbols=dynamic_symbols,
            benchmark_batch_size=10,
            benchmark_repetitions=3,
            benchmark_results_file=wave_results,
        )
    )
    compiled_gemm = wave_compile(options, gemm)

    # Operand allocation order is kept as a, b, c.
    a = device_randn(m, k, dtype=torch.float16)
    b = device_randn(n, k, dtype=torch.float16)
    c = device_zeros(m, n, dtype=torch.float32)
    compiled_gemm(a, b, c)

    assert wave_results.exists()
    assert "real_time" in wave_results.read_text()

    # Reuse the same options for the IREE reference, redirected to its own
    # results file.
    assert not iree_results.exists()
    options.benchmark_results_file = iree_results

    iree_ref = device_zeros(m, n, dtype=torch.float32)
    generate_iree_ref("mmt", [a, b], [iree_ref], options)

    assert iree_results.exists()
    assert "real_time" in iree_results.read_text()
113+
114+
71115@require_e2e
72116@pytest .mark .parametrize ("shape" , get_test_shapes ("test_gemm" ))
73117@pytest .mark .parametrize (
@@ -130,7 +174,7 @@ def testPureGemm(
130174 options .benchmark_results_file = perf_filename_iree
131175
132176 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
133- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
177+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
134178 assert_close (c , iree_ref , check_device = False )
135179
136180
@@ -202,7 +246,7 @@ def testGemmGatherToLDS(
202246 options .benchmark_results_file = perf_filename_iree
203247
204248 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
205- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
249+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
206250 assert_close (c , iree_ref , check_device = False )
207251
208252
@@ -336,7 +380,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
336380 options .benchmark_results_file = perf_filename_iree
337381
338382 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
339- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
383+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
340384 assert_close (c , iree_ref , check_device = False )
341385
342386
@@ -574,7 +618,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
574618 options .benchmark_results_file = perf_filename_iree
575619
576620 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
577- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
621+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
578622 assert_close (c , iree_ref , check_device = False )
579623
580624
@@ -627,7 +671,7 @@ def testGemmDumpOverrideSchedule(
627671 options .benchmark_results_file = perf_filename_iree
628672
629673 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
630- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
674+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
631675 assert_close (c , iree_ref , check_device = False )
632676
633677 # Now reload the schedule and run the kernel again.
@@ -784,7 +828,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
784828 options .benchmark_results_file = perf_filename_iree
785829
786830 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
787- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
831+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
788832 assert_close (c , iree_ref , check_device = False , atol = 1e-3 , rtol = 1e-3 )
789833
790834
@@ -913,7 +957,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
913957 options .benchmark_results_file = perf_filename_iree
914958
915959 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
916- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
960+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
917961 assert_close (c , iree_ref , atol = 2e-4 , rtol = 3e-4 , check_device = False )
918962
919963
@@ -1044,7 +1088,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
10441088 options .benchmark_results_file = perf_filename_iree
10451089
10461090 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .int32 )
1047- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
1091+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
10481092 assert_close (c , iree_ref , check_device = False )
10491093
10501094
@@ -1151,7 +1195,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
11511195 options .benchmark_results_file = perf_filename_iree
11521196
11531197 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .int32 )
1154- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
1198+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
11551199 assert_close (c , iree_ref , check_device = False )
11561200
11571201
@@ -1255,7 +1299,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
12551299 options .benchmark_results_file = perf_filename_iree
12561300
12571301 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
1258- generate_iree_ref ("mmt_f8" , [a , b ], [iree_ref ])
1302+ generate_iree_ref ("mmt_f8" , [a , b ], [iree_ref ], options )
12591303 assert_close (c , iree_ref , atol = 3e-5 , rtol = 3e-4 , check_device = False )
12601304
12611305
@@ -1382,7 +1426,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
13821426 options .benchmark_results_file = perf_filename_iree
13831427
13841428 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
1385- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
1429+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
13861430 assert_close (c , iree_ref , check_device = False )
13871431
13881432
@@ -1516,7 +1560,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
15161560 options .benchmark_results_file = perf_filename_iree
15171561
15181562 iree_ref = device_zeros (shape [0 ], shape [1 ], dtype = torch .float32 )
1519- generate_iree_ref ("mmt" , [a , b ], [iree_ref ])
1563+ generate_iree_ref ("mmt" , [a , b ], [iree_ref ], options )
15201564 assert_close (c , iree_ref , check_device = False )
15211565
15221566
@@ -1615,7 +1659,7 @@ def repeat(
16151659 options .benchmark_results_file = perf_filename_iree
16161660
16171661 iree_ref = device_zeros (shape [0 ], shape [1 ], shape [2 ], dtype = torch .float32 )
1618- generate_iree_ref ("bmmt" , [a , b ], [iree_ref ])
1662+ generate_iree_ref ("bmmt" , [a , b ], [iree_ref ], options )
16191663 assert_close (c , iree_ref , check_device = False )
16201664
16211665 torch_ref = torch .matmul (a , b .transpose (- 2 , - 1 ))
@@ -1719,7 +1763,7 @@ def repeat(
17191763 options .benchmark_results_file = perf_filename_iree
17201764
17211765 iree_ref = device_zeros (shape [0 ], shape [1 ], shape [2 ], dtype = torch .float32 )
1722- generate_iree_ref ("bmmt" , [a , b ], [iree_ref ])
1766+ generate_iree_ref ("bmmt" , [a , b ], [iree_ref ], options )
17231767 assert_close (c , iree_ref , check_device = False )
17241768
17251769 torch_ref = torch .matmul (a , b .transpose (- 2 , - 1 ))
0 commit comments