115115from ..compiler import host_codegen , kernel_codegen , builder , dispatch_codegen
116116from ..compiler .wave_codegen import WaveEmitter
117117from .compile_options import WaveCompileOptions
118+ from pathlib import Path
119+
118120from .cache import (
121+ get_cache_base_dir ,
119122 get_cache_manager ,
120123 get_temp_binary_dir ,
121124 is_cache_enabled ,
@@ -402,8 +405,10 @@ def __init__(
402405 return
403406 self ._engine = get_execution_engine ()
404407 self ._module_handle = self ._engine .load_module_from_text (optimized_mlir )
408+ self ._bind_host_func ()
405409
406- # Look up the host wrapper function
410+ def _bind_host_func (self ):
411+ """Look up the host wrapper function and create a ctypes callable."""
407412 func_name = self .options .func_name
408413 try :
409414 self ._host_func_ptr = self ._engine .lookup (self ._module_handle , func_name )
@@ -415,14 +420,38 @@ def __init__(
415420
416421 # Create ctypes function type
417422 # The host wrapper signature is: void func(void* stream, PyObject* arg0, PyObject* arg1, ...)
418-
419423 num_kernel_args = len (self .options .kernel_usages )
420424 arg_types = [ctypes .c_void_p ] + [
421425 py_object
422426 ] * num_kernel_args # +1 for stream pointer
423427 func_type = ctypes .CFUNCTYPE (None , * arg_types )
424428 self ._cfunc = func_type (self ._host_func_ptr )
425429
def dump_to_object_file(self, path: str):
    """Dump the compiled host object file (with embedded GPU binary) to disk.

    Args:
        path: Destination path, forwarded unchanged to the execution
            engine's ``dump_to_object_file``.

    Raises:
        RuntimeError: If this instance has no execution engine attached
            (``_engine`` is unset or ``None``).
    """
    engine = getattr(self, "_engine", None)
    if engine is None:
        # Explicit raise instead of ``assert`` so the guard is not stripped
        # when Python runs with optimizations enabled (``-O``).
        raise RuntimeError("no execution engine to dump from")
    engine.dump_to_object_file(path)
434+
@classmethod
def from_object_file(
    cls,
    options: WaveCompileOptions,
    object_file_path: str,
    mlir_asm: str = "",
) -> "WaveKernelExecutionEngine":
    """Build an engine wrapper from a cached object file.

    ``__init__`` is bypassed (the instance is created with ``__new__``), so
    no MLIR is compiled here. Only ``options``, ``asm``, ``_engine`` and
    ``_module_handle`` are assigned before the host function is bound.

    Args:
        options: Compile options; stored on the instance as ``options``.
        object_file_path: Path handed to the execution engine's
            ``load_from_object_file``.
        mlir_asm: MLIR text stored on the instance as ``asm``.

    Returns:
        The new instance, after ``_bind_host_func()`` has been called on it.
    """
    from wave_lang.kernel.wave.execution_engine import get_execution_engine

    execution_engine = get_execution_engine()

    loaded = cls.__new__(cls)
    loaded.options = options
    loaded.asm = mlir_asm
    loaded._engine = execution_engine
    loaded._module_handle = execution_engine.load_from_object_file(object_file_path)
    loaded._bind_host_func()
    return loaded
454+
def __call__(self, *args):
    """Forward all positional arguments to ``invoke`` and return its result."""
    result = self.invoke(*args)
    return result
428457
@@ -1013,6 +1042,11 @@ def get_binary_path():
10131042 else :
10141043 return glob .glob (str (get_temp_binary_dir () / "*.hsaco" ))[0 ]
10151044
def _get_water_object_cache_path(kernel_hash: str) -> Path:
    """Return the path for a cached Water object file.

    The result is ``<base>/<kernel_hash>/<kernel_hash>.o``. ``<base>`` is
    ``cache_manager.base_dir`` when a cache manager is present, and
    ``get_cache_base_dir()`` otherwise.

    Args:
        kernel_hash: Hash naming both the cache subdirectory and the
            object file inside it.
    """
    # Compare against None explicitly, matching the ``cache_manager is not
    # None`` guard used where the object file is written, rather than
    # relying on the manager object's truthiness.
    if cache_manager is not None:
        base = cache_manager.base_dir
    else:
        base = get_cache_base_dir()
    # Coerce so the annotated ``Path`` return holds even if the base
    # directory is supplied as a plain string.
    return Path(base) / kernel_hash / f"{kernel_hash}.o"
1049+
10161050 # Create an indexing context and populate substitutions.
10171051 with IndexingContext () as idxc :
10181052 idxc .set_subs (options .subs )
@@ -1037,22 +1071,32 @@ def get_binary_path():
10371071 if cached_kernel :
10381072 options .kernel_usages = cached_kernel .kernel_sig
10391073 options .kernel_launch_info = cached_kernel .kernel_launch_info
1040- if options .wave_runtime :
1041- binary_path = get_binary_path ()
10421074
10431075 if options .print_mlir :
10441076 print (cached_kernel .asm )
10451077
1046- return cls (
1047- options ,
1048- cached_kernel .vmfb ,
1049- cached_kernel .asm ,
1050- binary_path ,
1051- bound_scalar_symbols ,
1052- symbols_args_map ,
1053- None ,
1054- None ,
1055- )
1078+ if options .use_water_backend :
1079+ obj_path = _get_water_object_cache_path (options .kernel_hash )
1080+ if obj_path .exists ():
1081+ return WaveKernelExecutionEngine .from_object_file (
1082+ options , str (obj_path ), cached_kernel .asm
1083+ )
1084+ # Object file missing from cache, fall through
1085+ # to recompilation.
1086+ else :
1087+ if options .wave_runtime :
1088+ binary_path = get_binary_path ()
1089+
1090+ return cls (
1091+ options ,
1092+ cached_kernel .vmfb ,
1093+ cached_kernel .asm ,
1094+ binary_path ,
1095+ bound_scalar_symbols ,
1096+ symbols_args_map ,
1097+ None ,
1098+ None ,
1099+ )
10561100
10571101 # For the wave runtime, we need the hsaco binary. So we turn on
10581102 # dumping of binaries and store in wave runtime directory. If we
@@ -1176,12 +1220,25 @@ def get_binary_path():
11761220 _compile_asm_to_binary (asm , options )
11771221 elif options .use_water_backend :
11781222 module = water_lowering_pipeline (mb .module_op , options )
1179- return WaveKernelExecutionEngine (
1223+ engine = WaveKernelExecutionEngine (
11801224 options ,
11811225 module ,
11821226 asm ,
11831227 create_execution_engine = not options .compile_to_mlir ,
11841228 )
1229+ # Cache the compiled object file for future runs.
1230+ if (
1231+ is_cache_enabled ()
1232+ and cache_manager is not None
1233+ and options .kernel_hash
1234+ and not debug_arg_info
1235+ and not options .compile_to_mlir
1236+ ):
1237+ obj_path = _get_water_object_cache_path (options .kernel_hash )
1238+ obj_path .parent .mkdir (parents = True , exist_ok = True )
1239+ engine .dump_to_object_file (str (obj_path ))
1240+ cache_manager .store_kernel (None , asm , options )
1241+ return engine
11851242 elif not options .compile_to_mlir :
11861243 # LLVM flow: only compile to VMFB when not in MLIR-only mode
11871244 compiled_wave_vmfb = compile_to_vmfb (asm , options )
0 commit comments