7
7
import iree .turbine .kernel .lang as tkl
8
8
import iree .turbine .kernel .wave as tkw
9
9
from iree .turbine .kernel .lang .global_symbols import *
10
- from iree .turbine .kernel .wave .utils import (
11
- get_default_run_config ,
10
+ from iree .turbine .kernel .wave .compile import wave_compile , WaveCompileOptions
11
+ from iree . turbine . kernel . wave . utils . general_utils import (
12
12
get_default_scheduling_params ,
13
13
)
14
- except ImportError :
14
+ from iree .turbine .kernel .wave .scheduling .schedule_enums import SchedulingType
15
+ except ImportError as e :
15
16
TURBINE_AVAILABLE = False
17
+ turbine_import_error = e
16
18
else :
17
19
TURBINE_AVAILABLE = True
18
20
@@ -262,7 +264,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
262
264
tkw .write (repeat , c , elements_per_thread = STORE_ELEMS_PER_THREAD )
263
265
264
266
shape = [config .M , config .N , config .K ]
265
- schedule = config .K < 4096
267
+ schedule = SchedulingType . MODULO if config .K < 4096 else SchedulingType . NONE
266
268
267
269
hyperparams = {
268
270
ADDRESS_SPACE : SHARED_ADDRESS_SPACE ,
@@ -276,20 +278,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
276
278
K : shape [2 ],
277
279
}
278
280
hyperparams .update (get_default_scheduling_params ())
279
- # config = get_default_run_config() TODO: detects device as CPU for some reason
280
- config = {"backend" : "rocm" , "device" : "hip" , "target" : "gfx942" }
281
281
282
282
# TODO: Scheduling is taking too long time with large K.
283
- with tk . gen . TestLaunchContext (
284
- hyperparams ,
283
+ options = WaveCompileOptions (
284
+ subs = hyperparams ,
285
285
canonicalize = True ,
286
286
create_vmfb_file = vmfb_file ,
287
- run_config = config ,
287
+ backend = "rocm" ,
288
+ target = "gfx942" ,
288
289
schedule = schedule ,
289
- ):
290
- mb = gemm ()
291
-
292
- return mb .module_op .get_asm ()
290
+ )
291
+ result = wave_compile (options , gemm )
292
+ return result .asm
293
293
294
294
295
295
def compile_gemm_config (
@@ -309,7 +309,9 @@ def compile_gemm_config(
309
309
310
310
# Generate mlir content
311
311
if tk and not TURBINE_AVAILABLE :
312
- raise ValueError ("Requested TK benchmarks but Turbine isn't available" )
312
+ raise ValueError (
313
+ f"Can't compile TK benchmark because of a failed import (most likely iree.turbine is missing): { turbine_import_error } "
314
+ )
313
315
if tk :
314
316
try :
315
317
mlir_content = generate_tk_mlir (config , vmfb_file )
0 commit comments