Skip to content

Commit a597b1f

Browse files
authored
Fix TK convbench to work with new iree-turbine API (#68)
iree-org/iree-turbine@172bfc6 made significant API changes that broke convbench's TK support. This commit changes convbench to use the new APIs. This commit also changes convbench to report the specific ImportError in case importing something from TK fails, rather than just claiming TK is missing (which can disguise the actual problem).
1 parent 26362bd commit a597b1f

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

iree_kernel_benchmark/convbench/wave_conv_utils.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@
99
import iree.turbine.kernel as tk
1010
import iree.turbine.kernel.lang as tkl
1111
from iree.turbine.kernel.wave.templates.conv import get_igemm_conv2d
12-
from iree.turbine.kernel.wave.utils import (
13-
get_default_arch,
14-
get_default_run_config,
15-
get_default_compile_config,
12+
from iree.turbine.kernel.wave.compile import wave_compile, WaveCompileOptions
13+
from iree.turbine.kernel.wave.scheduling.schedule_enums import SchedulingType
14+
from iree.turbine.kernel.wave.utils.torch_utils import (
1615
device_randn,
1716
device_randint,
1817
device_randperm,
1918
device_zeros,
2019
)
21-
except ImportError:
20+
except ImportError as e:
2221
TURBINE_AVAILABLE = False
22+
turbine_import_error = e
2323
else:
2424
TURBINE_AVAILABLE = True
2525

@@ -32,7 +32,9 @@ def compile_wave_conv_config(
3232
extra_compiler_args: list[str],
3333
) -> tuple[Path, Optional[Path]]:
3434
if not TURBINE_AVAILABLE:
35-
raise ValueError("iree.turbine package is not available")
35+
raise ValueError(
36+
f"Can't compile TK benchmark because of a failed import (most likely iree.turbine is missing): {turbine_import_error}"
37+
)
3638

3739
# Name with tag is used for filenames so that duplicate configs with
3840
# different tags will not clobber eachother.
@@ -98,19 +100,17 @@ def _compile_conv(config: ConvConfig, mlir_file: Path, vmfb_file: Path):
98100
else:
99101
raise ValueError(f"Unsupported op_type: {op_type}")
100102

101-
# config = get_default_run_config()
102-
config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
103-
104-
with tk.gen.TestLaunchContext(
105-
hyperparams,
103+
options = WaveCompileOptions(
104+
subs=hyperparams,
106105
canonicalize=True,
107106
create_vmfb_file=vmfb_file,
108-
run_config=config,
109-
schedule=False,
110-
inline=False,
111-
):
112-
mod = conv().module_op # This will generate vmfb file
113-
with open(mlir_file, "w") as f:
114-
f.write(str(mod))
115-
116-
print(f"Successfully compiled to {vmfb_file}")
107+
schedule=SchedulingType.NONE,
108+
# inline=False, (TODO: how to do this with new API?)
109+
backend="rocm",
110+
target="gfx942",
111+
)
112+
result = wave_compile(options, conv)
113+
with open(mlir_file, "w") as f:
114+
f.write(result.asm)
115+
116+
print(f"Successfully compiled to {vmfb_file}")

0 commit comments

Comments
 (0)