Skip to content

Commit c478b34

Browse files
authored
Fix TK gemmbench to work with new iree-turbine API (#70)
This commit is patterned after a597b1f, which solved the same problems in convbench for the same reasons.
1 parent 542bb44 commit c478b34

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

iree_kernel_benchmark/gemmbench/gemm_utils.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import iree.turbine.kernel.lang as tkl
88
import iree.turbine.kernel.wave as tkw
99
from iree.turbine.kernel.lang.global_symbols import *
10-
from iree.turbine.kernel.wave.utils import (
11-
get_default_run_config,
10+
from iree.turbine.kernel.wave.compile import wave_compile, WaveCompileOptions
11+
from iree.turbine.kernel.wave.utils.general_utils import (
1212
get_default_scheduling_params,
1313
)
14-
except ImportError:
14+
from iree.turbine.kernel.wave.scheduling.schedule_enums import SchedulingType
15+
except ImportError as e:
1516
TURBINE_AVAILABLE = False
17+
turbine_import_error = e
1618
else:
1719
TURBINE_AVAILABLE = True
1820

@@ -262,7 +264,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
262264
tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
263265

264266
shape = [config.M, config.N, config.K]
265-
schedule = config.K < 4096
267+
schedule = SchedulingType.MODULO if config.K < 4096 else SchedulingType.NONE
266268

267269
hyperparams = {
268270
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
@@ -276,20 +278,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
276278
K: shape[2],
277279
}
278280
hyperparams.update(get_default_scheduling_params())
279-
# config = get_default_run_config() TODO: detects device as CPU for some reason
280-
config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
281281

282282
# TODO: Scheduling is taking too long time with large K.
283-
with tk.gen.TestLaunchContext(
284-
hyperparams,
283+
options = WaveCompileOptions(
284+
subs=hyperparams,
285285
canonicalize=True,
286286
create_vmfb_file=vmfb_file,
287-
run_config=config,
287+
backend="rocm",
288+
target="gfx942",
288289
schedule=schedule,
289-
):
290-
mb = gemm()
291-
292-
return mb.module_op.get_asm()
290+
)
291+
result = wave_compile(options, gemm)
292+
return result.asm
293293

294294

295295
def compile_gemm_config(
@@ -309,7 +309,9 @@ def compile_gemm_config(
309309

310310
# Generate mlir content
311311
if tk and not TURBINE_AVAILABLE:
312-
raise ValueError("Requested TK benchmarks but Turbine isn't available")
312+
raise ValueError(
313+
f"Can't compile TK benchmark because of a failed import (most likely iree.turbine is missing): {turbine_import_error}"
314+
)
313315
if tk:
314316
try:
315317
mlir_content = generate_tk_mlir(config, vmfb_file)

0 commit comments

Comments
 (0)