Commit 7050dc6

[Wave] Fix IREE benchmarking (again) (#31)
And add a test this time.

Signed-off-by: Ivan Butygin <[email protected]>
1 parent 6188974 commit 7050dc6

7 files changed: +101 −31 lines
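At a glance: every test call site now threads its WaveCompileOptions into generate_iree_ref, so the IREE reference kernel is compiled to a VMFB and, when run_bench is set, benchmarked with the same flags as the Wave kernel. A minimal before/after sketch — a, b, iree_ref, and options are the names used in the tests below; the surrounding setup is elided:

# Before: the reference was JIT-compiled internally, with no benchmarking hook.
generate_iree_ref("mmt", [a, b], [iree_ref])

# After: the caller's compile/benchmark options are passed through explicitly.
generate_iree_ref("mmt", [a, b], [iree_ref], options)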

tests/kernel/wave/attention/chained_gemm_test.py

Lines changed: 2 additions & 2 deletions
@@ -166,7 +166,7 @@ def repeat(
         print(f"IR dumped to {filename}")

     iree_ref = device_zeros(batch, v_head_dim, q_seq_len, dtype=torch.float32)
-    generate_iree_ref("chain_mmt", [q, k, v], [iree_ref])
+    generate_iree_ref("chain_mmt", [q, k, v], [iree_ref], options)
     assert_close(output, iree_ref, check_device=False, atol=0, rtol=0)

     torch_qk = torch.matmul(q, k.transpose(-1, -2))
@@ -311,5 +311,5 @@ def repeat(
         f.write(asm)

     iree_ref = device_zeros(batch, v_head_dim, q_seq_len, dtype=torch.float32)
-    generate_iree_ref("chain_mmt_f8", [q, k, v], [iree_ref])
+    generate_iree_ref("chain_mmt_f8", [q, k, v], [iree_ref], options)
     assert_close(output, iree_ref, atol=7e-5, rtol=2e-3, check_device=False)

tests/kernel/wave/reordered_gemm_test.py

Lines changed: 1 addition & 1 deletion
@@ -89,5 +89,5 @@ def testReorderedPureGemm(
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)

tests/kernel/wave/wave_e2e_test.py

Lines changed: 1 addition & 0 deletions
@@ -1545,6 +1545,7 @@ def test_igemm_conv(
         "conv_2d_" + layout,
         [x, we],
         [iree_ref],
+        options,
     )


tests/kernel/wave/wave_gemm_test.py

Lines changed: 58 additions & 14 deletions
@@ -68,6 +68,50 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
     return default_test_shapes[test_name]


+@require_e2e
+def testGemmBench(tmp_path):
+    shape = (64, 64, 64)
+    perf_filename_tk = tmp_path / "wave_gemm_bench.txt"
+    perf_filename_iree = tmp_path / "iree_gemm_bench.txt"
+    enable_scheduling = SchedulingType.NONE
+    dynamic_dims = False
+    mfma_variant = MMAType.F32_16x16x16_F16
+    gemm, hyperparams, dynamic_symbols = get_gemm_kernel(
+        shape, dynamic_dims, mfma_variant, torch.float16
+    )
+
+    assert not perf_filename_tk.exists()
+
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=True,
+        schedule=enable_scheduling,
+        use_scheduling_barriers=enable_scheduling_barriers,
+        dynamic_symbols=dynamic_symbols,
+        benchmark_batch_size=10,
+        benchmark_repetitions=3,
+        benchmark_results_file=perf_filename_tk,
+    )
+    options = set_default_run_config(options)
+    gemm = wave_compile(options, gemm)
+
+    a = device_randn(shape[0], shape[2], dtype=torch.float16)
+    b = device_randn(shape[1], shape[2], dtype=torch.float16)
+    c = device_zeros(shape[0], shape[1], dtype=torch.float32)
+    gemm(a, b, c)
+    assert perf_filename_tk.exists()
+    assert "real_time" in perf_filename_tk.read_text()
+
+    assert not perf_filename_iree.exists()
+    options.benchmark_results_file = perf_filename_iree
+
+    iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
+    assert perf_filename_iree.exists()
+    assert "real_time" in perf_filename_iree.read_text()
+
+
 @require_e2e
 @pytest.mark.parametrize("shape", get_test_shapes("test_gemm"))
 @pytest.mark.parametrize(
@@ -130,7 +174,7 @@ def testPureGemm(
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)


@@ -202,7 +246,7 @@ def testGemmGatherToLDS(
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)


@@ -336,7 +380,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)


@@ -574,7 +618,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)


@@ -627,7 +671,7 @@ def testGemmDumpOverrideSchedule(
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)

     # Now reload the schedule and run the kernel again.
@@ -784,7 +828,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False, atol=1e-3, rtol=1e-3)


@@ -913,7 +957,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, atol=2e-4, rtol=3e-4, check_device=False)


@@ -1044,7 +1088,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.int32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)


@@ -1151,7 +1195,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.int32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)


@@ -1255,7 +1299,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt_f8", [a, b], [iree_ref])
+    generate_iree_ref("mmt_f8", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, atol=3e-5, rtol=3e-4, check_device=False)


@@ -1382,7 +1426,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)


@@ -1516,7 +1560,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], dtype=torch.float32)
-    generate_iree_ref("mmt", [a, b], [iree_ref])
+    generate_iree_ref("mmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)


@@ -1615,7 +1659,7 @@ def repeat(
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
-    generate_iree_ref("bmmt", [a, b], [iree_ref])
+    generate_iree_ref("bmmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)

     torch_ref = torch.matmul(a, b.transpose(-2, -1))
@@ -1719,7 +1763,7 @@ def repeat(
     options.benchmark_results_file = perf_filename_iree

     iree_ref = device_zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
-    generate_iree_ref("bmmt", [a, b], [iree_ref])
+    generate_iree_ref("bmmt", [a, b], [iree_ref], options)
     assert_close(c, iree_ref, check_device=False)

     torch_ref = torch.matmul(a, b.transpose(-2, -1))
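The new testGemmBench only asserts that each results file exists and contains the substring "real_time". For context, a self-contained sketch of how such a results file could be inspected afterwards; the exact JSON layout emitted by print_bench_result is an assumption here (the helper simply dumps whatever benchmark_module returns), and read_real_times is a hypothetical helper, not part of the change:

import json

def read_real_times(path):
    """Collect any 'real_time' values from a JSON benchmark results file."""
    with open(path) as f:
        data = json.load(f)
    # Google-benchmark-style output nests individual runs under "benchmarks";
    # otherwise treat the document as a flat list of runs.
    runs = data.get("benchmarks", data) if isinstance(data, dict) else data
    return [run["real_time"] for run in runs if isinstance(run, dict) and "real_time" in run]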

wave_lang/kernel/wave/iree_utils.py

Lines changed: 26 additions & 1 deletion
@@ -8,6 +8,10 @@

 from wave_lang.runtime.launch import Launchable
 from wave_lang.support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM
+from .utils.run_utils import get_benchmark_flags, print_bench_result
+from .profiling import benchmark_module
+from .utils.compile_utils import compile_to_vmfb
+import iree.runtime as rt


 def get_chain_mmt_asm(
@@ -161,6 +165,7 @@ def generate_iree_ref(
     kernel_type: str,
     kernel_inputs: list[torch.Tensor],
     kernel_outputs: list[torch.Tensor],
+    options: "WaveCompileOptions",
 ):
     """
     Generate a reference output for the given kernel type and arguments.
@@ -211,10 +216,30 @@
     else:
         raise ValueError(f"Unknown kernel type: {kernel_type}")

-    launchable = Launchable.jit_compile(asm, entry_point=func_name)
+    vmfb = compile_to_vmfb(asm, options)
+
+    def loader(device):
+        vm_instance = device.vm_instance
+        return rt.VmModule.copy_buffer(vm_instance, vmfb)
+
+    launchable = Launchable.from_vm_module(loader, entry_point=func_name)
     res = launchable(*kernel_inputs, outputs=kernel_outputs)
     if len(kernel_outputs) == 1:
         kernel_outputs[0][:] = res
     else:
         for r, k in zip(res, kernel_outputs):
             k[:] = r
+
+    if options.run_bench:
+        benchmark_flags = get_benchmark_flags(options)
+
+        benchmark_results = benchmark_module(
+            options,
+            kernel_inputs,
+            [], # kernel_outputs,
+            [], # dynamic_symbols,
+            vmfb,
+            func_name,
+            **benchmark_flags,
+        )
+        print_bench_result(benchmark_results, options.benchmark_results_file)
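Condensing the hunk above: generate_iree_ref now compiles the reference MLIR once with compile_to_vmfb and reuses those VMFB bytes both for execution (via a loader passed to Launchable.from_vm_module) and for the optional benchmark pass — bytes the previous Launchable.jit_compile path did not expose. A sketch of the new control flow, with the intent of each step in comments (all names are the ones introduced in the diff):

vmfb = compile_to_vmfb(asm, options)  # compile once, keep the serialized module bytes

def loader(device):
    # Instantiate the module inside the VM instance that owns the target device.
    return rt.VmModule.copy_buffer(device.vm_instance, vmfb)

launchable = Launchable.from_vm_module(loader, entry_point=func_name)
res = launchable(*kernel_inputs, outputs=kernel_outputs)

if options.run_bench:
    # The same VMFB is handed to benchmark_module, so the timing covers exactly what just ran.
    benchmark_module(options, kernel_inputs, [], [], vmfb, func_name,
                     **get_benchmark_flags(options))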

wave_lang/kernel/wave/profiling.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def benchmark_module(

     err = err.decode()
     if "INVALID_ARGUMENT;" in err:
-        raise ValueError("Invalid inputs specified for benchmarking")
+        raise ValueError(f"Invalid inputs specified for benchmarking:\n{err}")

     # In the event benchmarking runs but encounteres an internal error,
     # return the internal error instead of benchmark results.

wave_lang/kernel/wave/utils/run_utils.py

Lines changed: 12 additions & 12 deletions
@@ -87,7 +87,7 @@ def push_tensor_to_arg_list(arg_tensor: torch.Tensor):
         ) from e


-def _print_bench_result(result, filename):
+def print_bench_result(result, filename):
     import json

     res = json.dumps(result, sort_keys=True, indent=4)
@@ -97,6 +97,15 @@ def _print_bench_result(result, filename):
     print(res)


+def get_benchmark_flags(options: WaveCompileOptions):
+    benchmark_flags = {}
+    benchmark_flags["batch_size"] = options.benchmark_batch_size
+
+    if options.benchmark_repetitions is not None:
+        benchmark_flags["benchmark_repetitions"] = int(options.benchmark_repetitions)
+    return benchmark_flags
+
+
 def invoke_vmfb(
     vmfb: bytes,
     options: WaveCompileOptions,
@@ -120,16 +129,6 @@ def invoke_vmfb(
         return

     device = options.device
-    if options.run_bench:
-        benchmark_flags = {}
-        # If we use 1000 for bench_batch_size during compilation, and set this batch size to 1,
-        # then the latency is in milliseconds.
-        benchmark_flags["batch_size"] = 1
-
-        if options.benchmark_repetitions is not None:
-            benchmark_flags["benchmark_repetitions"] = int(
-                options.benchmark_repetitions
-            )

     # Select device as the GPU, where input tensors are coming from.
     device_list = tuple(
@@ -170,6 +169,7 @@
     )

     if options.run_bench:
+        benchmark_flags = get_benchmark_flags(options)
         benchmark_results = benchmark_module(
             options,
             kernel_inputs,
@@ -179,7 +179,7 @@
             options.func_name,
             **benchmark_flags,
         )
-        _print_bench_result(benchmark_results, options.benchmark_results_file)
+        print_bench_result(benchmark_results, options.benchmark_results_file)


 def invoke_with_wave_runtime(
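The flag-building logic previously inlined in invoke_vmfb is now the shared get_benchmark_flags helper, which generate_iree_ref also uses; the batch size comes from options.benchmark_batch_size instead of being hard-coded to 1. A minimal usage sketch — constructing WaveCompileOptions with only these fields is an assumption for illustration; the tests above set many more:

opts = WaveCompileOptions(
    run_bench=True,
    benchmark_batch_size=10,
    benchmark_repetitions=3,
)
flags = get_benchmark_flags(opts)
# Both invoke_vmfb and generate_iree_ref forward these as **flags to benchmark_module.
assert flags == {"batch_size": 10, "benchmark_repetitions": 3}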
