123123from ..compiler import host_codegen , kernel_codegen , builder , dispatch_codegen
124124from ..compiler .wave_codegen import WaveEmitter
125125from .compile_options import WaveCompileOptions
126+ from pathlib import Path
127+
126128from .cache import (
129+ get_cache_base_dir ,
127130 get_cache_manager ,
128131 get_temp_binary_dir ,
129132 is_cache_enabled ,
@@ -414,8 +417,10 @@ def __init__(
414417 return
415418 self ._engine = get_execution_engine ()
416419 self ._module_handle = self ._engine .load_module_from_text (optimized_mlir )
420+ self ._bind_host_func ()
417421
418- # Look up the host wrapper function
422+ def _bind_host_func (self ):
423+ """Look up the host wrapper function and create a ctypes callable."""
419424 func_name = self .options .func_name
420425 try :
421426 self ._host_func_ptr = self ._engine .lookup (self ._module_handle , func_name )
@@ -427,14 +432,38 @@ def __init__(
427432
428433 # Create ctypes function type
429434 # The host wrapper signature is: void func(void* stream, PyObject* arg0, PyObject* arg1, ...)
430-
431435 num_kernel_args = len (self .options .kernel_usages )
432436 arg_types = [ctypes .c_void_p ] + [
433437 py_object
434438 ] * num_kernel_args # +1 for stream pointer
435439 func_type = ctypes .CFUNCTYPE (None , * arg_types )
436440 self ._cfunc = func_type (self ._host_func_ptr )
437441
def dump_to_object_file(self, path: str):
    """Dump the compiled host object file (with embedded GPU binary) to disk.

    Args:
        path: Destination passed through unchanged to the execution
            engine's own ``dump_to_object_file``.

    Raises:
        RuntimeError: If this instance has no execution engine attached.
    """
    # Use an explicit raise rather than ``assert``: assertions are removed
    # under ``python -O``, which would turn this precondition into an opaque
    # AttributeError on None. ``getattr`` also covers an instance on which
    # ``_engine`` was never assigned.
    engine = getattr(self, "_engine", None)
    if engine is None:
        raise RuntimeError("no execution engine to dump from")
    engine.dump_to_object_file(path)
446+
@classmethod
def from_object_file(
    cls,
    options: WaveCompileOptions,
    object_file_path: str,
    mlir_asm: str = "",
) -> "WaveKernelExecutionEngine":
    """Build an engine from a previously dumped object file.

    Alternate constructor that skips MLIR compilation. The instance is
    allocated with ``cls.__new__`` (so ``__init__`` does not run), its
    ``options``, ``asm``, ``_engine`` and ``_module_handle`` attributes are
    filled in by hand, and the host wrapper is then bound through
    ``_bind_host_func``.

    Args:
        options: Compile options stored on the instance.
        object_file_path: Path handed to the execution engine's
            ``load_from_object_file``.
        mlir_asm: MLIR text stored as ``asm``; defaults to an empty string.

    Returns:
        The newly created instance with its host function bound.
    """
    # Imported at call time, exactly as the original implementation does.
    from wave_lang.kernel.wave.execution_engine import get_execution_engine

    # Allocate without running __init__, then populate attributes manually.
    loaded = cls.__new__(cls)
    loaded.options, loaded.asm = options, mlir_asm

    execution_engine = get_execution_engine()
    loaded._engine = execution_engine
    loaded._module_handle = execution_engine.load_from_object_file(
        object_file_path
    )

    loaded._bind_host_func()
    return loaded
466+
def __call__(self, *args):
    """Make the object callable by forwarding all positional arguments to ``invoke``."""
    outcome = self.invoke(*args)
    return outcome
440469
@@ -1034,6 +1063,11 @@ def get_binary_path():
10341063 else :
10351064 return glob .glob (str (get_temp_binary_dir () / "*.hsaco" ))[0 ]
10361065
def _get_water_object_cache_path(kernel_hash: str) -> Path:
    """Return the path for a cached Water object file.

    The file is located at ``<base>/<kernel_hash>/<kernel_hash>.o``, where
    ``<base>`` is ``cache_manager.base_dir`` when a cache manager exists and
    ``get_cache_base_dir()`` otherwise.

    Args:
        kernel_hash: Hash string used both as the directory name and as the
            object file's stem.

    Returns:
        The object file path; the file itself is not created or checked.
    """
    # Compare against None explicitly instead of relying on truthiness, so a
    # manager object that happens to evaluate as falsy is still honoured.
    if cache_manager is not None:
        base = cache_manager.base_dir
    else:
        base = get_cache_base_dir()
    # Coerce defensively so ``/`` works even if the base is a plain string.
    return Path(base) / kernel_hash / (kernel_hash + ".o")
1070+
10371071 # Create an indexing context and populate substitutions.
10381072 with IndexingContext () as idxc :
10391073 idxc .set_subs (options .subs )
@@ -1058,22 +1092,32 @@ def get_binary_path():
10581092 if cached_kernel :
10591093 options .kernel_usages = cached_kernel .kernel_sig
10601094 options .kernel_launch_info = cached_kernel .kernel_launch_info
1061- if options .wave_runtime :
1062- binary_path = get_binary_path ()
10631095
10641096 if options .print_mlir :
10651097 print (cached_kernel .asm )
10661098
1067- return cls (
1068- options ,
1069- cached_kernel .vmfb ,
1070- cached_kernel .asm ,
1071- binary_path ,
1072- bound_scalar_symbols ,
1073- symbols_args_map ,
1074- None ,
1075- None ,
1076- )
1099+ if options .use_water_backend :
1100+ obj_path = _get_water_object_cache_path (options .kernel_hash )
1101+ if obj_path .exists ():
1102+ return WaveKernelExecutionEngine .from_object_file (
1103+ options , str (obj_path ), cached_kernel .asm
1104+ )
1105+ # Object file missing from cache, fall through
1106+ # to recompilation.
1107+ else :
1108+ if options .wave_runtime :
1109+ binary_path = get_binary_path ()
1110+
1111+ return cls (
1112+ options ,
1113+ cached_kernel .vmfb ,
1114+ cached_kernel .asm ,
1115+ binary_path ,
1116+ bound_scalar_symbols ,
1117+ symbols_args_map ,
1118+ None ,
1119+ None ,
1120+ )
10771121
10781122 # For the wave runtime, we need the hsaco binary. So we turn on
10791123 # dumping of binaries and store in wave runtime directory. If we
@@ -1210,12 +1254,25 @@ def get_binary_path():
12101254 _compile_asm_to_binary (asm , options )
12111255 elif options .use_water_backend :
12121256 module = water_lowering_pipeline (mb .module_op , options )
1213- return WaveKernelExecutionEngine (
1257+ engine = WaveKernelExecutionEngine (
12141258 options ,
12151259 module ,
12161260 asm ,
12171261 create_execution_engine = not options .compile_to_mlir ,
12181262 )
1263+ # Cache the compiled object file for future runs.
1264+ if (
1265+ is_cache_enabled ()
1266+ and cache_manager is not None
1267+ and options .kernel_hash
1268+ and not debug_arg_info
1269+ and not options .compile_to_mlir
1270+ ):
1271+ obj_path = _get_water_object_cache_path (options .kernel_hash )
1272+ obj_path .parent .mkdir (parents = True , exist_ok = True )
1273+ engine .dump_to_object_file (str (obj_path ))
1274+ cache_manager .store_kernel (None , asm , options )
1275+ return engine
12191276 elif not options .compile_to_mlir :
12201277 # LLVM flow: only compile to VMFB when not in MLIR-only mode
12211278 compiled_wave_vmfb = compile_to_vmfb (asm , options )
0 commit comments