Skip to content

Commit 4586fab

Browse files
authored
[boo tuner] Switch to newer BOO backend (#2670)
This switches the BOO tuner to use the newer `iree_boo_experimental` BOO backend, which uses `torch.compile` under the hood.
1 parent 1421bae commit 4586fab

File tree

3 files changed

+20
-23
lines changed

3 files changed

+20
-23
lines changed

sharktuner/boo_tuner/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ source ../iree-build/.env && export PYTHONPATH
3434
([IREE Turbine ROCm requirements](https://github.com/iree-org/iree-turbine/blob/main/pytorch-rocm-requirements.txt)):
3535

3636
```shell
37-
pip install --index-url https://download.pytorch.org/whl/rocm6.3 'torch>=2.5,<2.8'
37+
pip install --index-url https://download.pytorch.org/whl/rocm6.4 'torch>=2.9'
3838
```
3939

4040
### Install IREE Turbine

sharktuner/boo_tuner/boo_tuner.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,14 @@ def process_boo_command(
248248
boo_path_config: BooPathConfig,
249249
root_logger: logging.Logger,
250250
starter_td_spec: Path | None,
251-
boo_runtime,
252-
get_launchable,
253-
BooOpRegistry,
254251
) -> Path | None:
255252
"""Process a single BOO command through compilation and tuning."""
253+
# These imports are slow due to a pytorch dependency. Keeping them local helps
254+
# make '--help' fast.
255+
from iree.turbine.kernel.boo.driver.launch import get_launchable
256+
from iree.turbine.kernel.boo import runtime as boo_runtime
257+
from iree.turbine.kernel.boo.op_exports.registry import BooOpRegistry
258+
256259
sig = BooOpRegistry.parse_command(cli_args, ignore_unhandled_args=True)
257260
if sig is None:
258261
raise ValueError(f"Boo op registry failed to parse '{shlex.join(cli_args)}'.")
@@ -272,8 +275,14 @@ def process_boo_command(
272275

273276
# Run BOO compilation and extract source IR.
274277
with boo_runtime.use_cache_dir(boo_cache_dir):
278+
# The "iree_boo" backend offloads to IREE in cases where we expect
279+
# performance to be better, and falls back to pytorch otherwise. We use
280+
# the experimental backend here instead, as we want to use IREE in all
281+
# cases.
275282
# Note: device="cuda" is correct for AMD GPUs.
276-
get_launchable(sig)(*sig.get_sample_args(device="cuda", seed=123))
283+
sig.get_compiled_module(backend="iree_boo_experimental")(
284+
*sig.get_sample_args(device="cuda", seed=123)
285+
)
277286
[op_cache_dir] = os.listdir(boo_cache_dir)
278287
op_cache_path = boo_cache_dir / op_cache_dir
279288

@@ -347,20 +356,12 @@ def process_boo_command(
347356
return args.output_td_spec if best_spec_path else None
348357

349358

350-
def load_boo() -> tuple[types.ModuleType, Callable, type]:
351-
"""Load BOO runtime modules.
352-
353-
These imports are slow due to a pytorch dependency. Keeping them in a
354-
separate function helps make '--help' fast.
355-
"""
356-
from iree.turbine.kernel.boo import runtime as boo_runtime
357-
from iree.turbine.kernel.boo.driver.launch import get_launchable
358-
from iree.turbine.kernel.boo.op_exports.registry import BooOpRegistry
359-
360-
return boo_runtime, get_launchable, BooOpRegistry
361-
362-
363359
def main() -> None:
360+
# Set saner defaults for pytorch/miopen environment variables. This affects
361+
# pytorch's inferred tensor layouts on AMDGPU, even when not actually using
362+
# MIOpen kernels, and is required for performance.
363+
os.environ.setdefault("PYTORCH_MIOPEN_SUGGEST_NHWC", "1")
364+
364365
parsed_args: tuple[argparse.Namespace, list[str]] = parse_args()
365366
args, miopen_op_args = parsed_args
366367

@@ -382,7 +383,6 @@ def main() -> None:
382383
libtuner.validate_devices(args.devices)
383384
logging.info("Validation successful!")
384385

385-
boo_runtime, get_launchable, BooOpRegistry = load_boo()
386386
logging.getLogger("turbine").setLevel(logging.WARNING)
387387

388388
mio_args = load_commands_from_file_or_args(args.commands_file, miopen_op_args)
@@ -398,9 +398,6 @@ def main() -> None:
398398
boo_path_config,
399399
root_logger,
400400
starter_td_spec,
401-
boo_runtime,
402-
get_launchable,
403-
BooOpRegistry,
404401
)
405402

406403
# Update starter spec for next iteration if tuning succeeded.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
convbfp16 -n 3 -c 2016 -H 1 -W 1 -k 224 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -m conv -g 1 -F 1 -t 1
1+
convbfp16 -n 3 -c 2016 -H 4 -W 4 -k 224 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -m conv -g 1 -F 1 -t 1
22
gemmfp16 --a_w 32 --a_h 64 --b_w 128

0 commit comments

Comments
 (0)