@@ -248,11 +248,14 @@ def process_boo_command(
248248 boo_path_config : BooPathConfig ,
249249 root_logger : logging .Logger ,
250250 starter_td_spec : Path | None ,
251- boo_runtime ,
252- get_launchable ,
253- BooOpRegistry ,
254251) -> Path | None :
255252 """Process a single BOO command through compilation and tuning."""
253+ # These imports are slow due to a pytorch dependency. Keeping them local helps
254+ # make '--help' fast.
255+ from iree .turbine .kernel .boo .driver .launch import get_launchable
256+ from iree .turbine .kernel .boo import runtime as boo_runtime
257+ from iree .turbine .kernel .boo .op_exports .registry import BooOpRegistry
258+
256259 sig = BooOpRegistry .parse_command (cli_args , ignore_unhandled_args = True )
257260 if sig is None :
258261 raise ValueError (f"Boo op registry failed to parse '{ shlex .join (cli_args )} '." )
@@ -272,8 +275,14 @@ def process_boo_command(
272275
273276 # Run BOO compilation and extract source IR.
274277 with boo_runtime .use_cache_dir (boo_cache_dir ):
278+ # The "iree_boo" backend offloads to IREE in cases where we expect
279+ # performance to be better, and falls back to pytorch otherwise. We use
280+ # the experimental backend here instead, as we want to use IREE in all
281+ # cases.
275282 # Note: device="cuda" is correct for AMD GPUs.
276- get_launchable (sig )(* sig .get_sample_args (device = "cuda" , seed = 123 ))
283+ sig .get_compiled_module (backend = "iree_boo_experimental" )(
284+ * sig .get_sample_args (device = "cuda" , seed = 123 )
285+ )
277286 [op_cache_dir ] = os .listdir (boo_cache_dir )
278287 op_cache_path = boo_cache_dir / op_cache_dir
279288
@@ -347,20 +356,12 @@ def process_boo_command(
347356 return args .output_td_spec if best_spec_path else None
348357
349358
350- def load_boo () -> tuple [types .ModuleType , Callable , type ]:
351- """Load BOO runtime modules.
352-
353- These imports are slow due to a pytorch dependency. Keeping them in a
354- separate function helps make '--help' fast.
355- """
356- from iree .turbine .kernel .boo import runtime as boo_runtime
357- from iree .turbine .kernel .boo .driver .launch import get_launchable
358- from iree .turbine .kernel .boo .op_exports .registry import BooOpRegistry
359-
360- return boo_runtime , get_launchable , BooOpRegistry
361-
362-
363359def main () -> None :
360+ # Set saner defaults for pytorch/miopen environment variables. This affects
361+ # pytorch's inferred tensor layouts on AMDGPU, even when not actually using
362+ # MIOpen kernels, and are required for performance.
363+ os .environ .setdefault ("PYTORCH_MIOPEN_SUGGEST_NHWC" , "1" )
364+
364365 parsed_args : tuple [argparse .Namespace , list [str ]] = parse_args ()
365366 args , miopen_op_args = parsed_args
366367
@@ -382,7 +383,6 @@ def main() -> None:
382383 libtuner .validate_devices (args .devices )
383384 logging .info ("Validation successful!" )
384385
385- boo_runtime , get_launchable , BooOpRegistry = load_boo ()
386386 logging .getLogger ("turbine" ).setLevel (logging .WARNING )
387387
388388 mio_args = load_commands_from_file_or_args (args .commands_file , miopen_op_args )
@@ -398,9 +398,6 @@ def main() -> None:
398398 boo_path_config ,
399399 root_logger ,
400400 starter_td_spec ,
401- boo_runtime ,
402- get_launchable ,
403- BooOpRegistry ,
404401 )
405402
406403 # Update starter spec for next iteration if tuning succeeded.