115115from ..compiler import host_codegen , kernel_codegen , builder , dispatch_codegen
116116from ..compiler .wave_codegen import WaveEmitter
117117from .compile_options import WaveCompileOptions
118+ from pathlib import Path
119+
118120from .cache import (
121+ get_cache_base_dir ,
119122 get_cache_manager ,
120123 get_temp_binary_dir ,
121124 is_cache_enabled ,
@@ -386,24 +389,23 @@ def __init__(
386389 self ._module_handle = None
387390 self ._host_func_ptr = None
388391
389- # Serialize MLIR module to text if needed
390- # TODO: investigate why bytecode deserialization is not working
392+ # Serialize MLIR module to text if needed.
393+ # TODO: investigate why bytecode deserialization is not working.
391394 if isinstance (module , (bytes , str )):
392- # Assume it's already MLIR text
393395 optimized_mlir = module .decode () if isinstance (module , bytes ) else module
394396 else :
395- # Serialize the MLIR module to text
396397 optimized_mlir = str (module )
397398
398- # Get the execution engine instance and load the module
399399 from wave_lang .kernel .wave .execution_engine import get_execution_engine
400400
401401 if not create_execution_engine :
402402 return
403403 self ._engine = get_execution_engine ()
404404 self ._module_handle = self ._engine .load_module_from_text (optimized_mlir )
405+ self ._bind_host_func ()
405406
406- # Look up the host wrapper function
407+ def _bind_host_func (self ):
408+ """Look up the host wrapper function and create a ctypes callable."""
407409 func_name = self .options .func_name
408410 try :
409411 self ._host_func_ptr = self ._engine .lookup (self ._module_handle , func_name )
@@ -413,32 +415,49 @@ def __init__(
413415 f"Make sure the module was compiled with emit_host_func. Error: { e } "
414416 )
415417
416- # Create ctypes function type
417- # The host wrapper signature is: void func(void* stream, PyObject* arg0, PyObject* arg1, ...)
418-
418+ # The host wrapper signature is:
419+ # void func(void* stream, PyObject* arg0, PyObject* arg1, ...).
419420 num_kernel_args = len (self .options .kernel_usages )
420- arg_types = [ctypes .c_void_p ] + [
421- py_object
422- ] * num_kernel_args # +1 for stream pointer
421+ arg_types = [ctypes .c_void_p ] + [py_object ] * num_kernel_args
423422 func_type = ctypes .CFUNCTYPE (None , * arg_types )
424423 self ._cfunc = func_type (self ._host_func_ptr )
425424
def dump_to_object_file(self, path: str):
    """Write the compiled host object file (with embedded GPU binary) to disk.

    The call is forwarded to the execution engine held by this instance.

    Args:
        path: Destination file path handed to the engine.

    Raises:
        AssertionError: If this instance has no execution engine.
    """
    # Raise explicitly instead of using `assert`: assert statements are
    # stripped under `python -O`, which would turn this precondition into an
    # opaque AttributeError on None. The exception type and message are kept
    # so existing callers observe the same failure.
    if self._engine is None:
        raise AssertionError("no execution engine to dump from")
    self._engine.dump_to_object_file(path)
429+
@classmethod
def from_object_file(
    cls,
    options: WaveCompileOptions,
    object_file_path: str,
    mlir_asm: str = "",
) -> "WaveKernelExecutionEngine":
    """Build an engine wrapper from a cached object file, skipping MLIR compilation.

    Args:
        options: Compile options stored on the new instance; the host
            function is looked up from them when binding.
        object_file_path: Path of the object file to load into the engine.
        mlir_asm: MLIR text stored on the instance as ``asm``.

    Returns:
        A new instance whose host function has been bound.
    """
    from wave_lang.kernel.wave.execution_engine import get_execution_engine

    # Allocate without running __init__, which would load the module from
    # MLIR text; every attribute needed for binding is populated by hand.
    loaded = cls.__new__(cls)
    shared_engine = get_execution_engine()
    loaded.options = options
    loaded.asm = mlir_asm
    loaded._engine = shared_engine
    loaded._module_handle = shared_engine.load_from_object_file(object_file_path)
    loaded._bind_host_func()
    return loaded
449+
def __call__(self, *args):
    """Make the object callable by forwarding all positional arguments to ``invoke``."""
    outcome = self.invoke(*args)
    return outcome
428452
def invoke(self, *args) -> None:
    """Invoke the wave kernel with the given arguments using the ExecutionEngine.

    Each positional argument is wrapped in ``py_object`` and passed to the
    bound host function after the current torch CUDA stream handle.

    Raises:
        AssertionError: If no execution engine was created for this instance.
    """
    # Raise explicitly instead of using `assert`: under `python -O` an assert
    # is removed, and the call would then fail later with an unrelated error.
    # The exception type and message are kept so existing callers observe the
    # same failure.
    if self._engine is None:
        raise AssertionError(
            "Cannot invoke kernel without creating an execution engine."
        )

    stream_ptr = torch.cuda.current_stream().cuda_stream
    # Host wrapper signature:
    #   void func(void* stream, PyObject* arg0, PyObject* arg1, ...).
    self._cfunc(stream_ptr, *(py_object(arg) for arg in args))
443462
444463
@@ -1013,6 +1032,11 @@ def get_binary_path():
10131032 else :
10141033 return glob .glob (str (get_temp_binary_dir () / "*.hsaco" ))[0 ]
10151034
def _get_water_object_cache_path(kernel_hash: str) -> Path:
    """Return where the Water object file for ``kernel_hash`` is cached.

    The layout is ``<cache root>/<kernel_hash>/<kernel_hash>.o``. The root is
    the cache manager's ``base_dir`` when a cache manager is available,
    otherwise the value of ``get_cache_base_dir()``.
    """
    if cache_manager:
        cache_root = cache_manager.base_dir
    else:
        cache_root = get_cache_base_dir()
    object_name = f"{kernel_hash}.o"
    return cache_root / kernel_hash / object_name
1039+
10161040 # Create an indexing context and populate substitutions.
10171041 with IndexingContext () as idxc :
10181042 idxc .set_subs (options .subs )
@@ -1037,22 +1061,32 @@ def get_binary_path():
10371061 if cached_kernel :
10381062 options .kernel_usages = cached_kernel .kernel_sig
10391063 options .kernel_launch_info = cached_kernel .kernel_launch_info
1040- if options .wave_runtime :
1041- binary_path = get_binary_path ()
10421064
10431065 if options .print_mlir :
10441066 print (cached_kernel .asm )
10451067
1046- return cls (
1047- options ,
1048- cached_kernel .vmfb ,
1049- cached_kernel .asm ,
1050- binary_path ,
1051- bound_scalar_symbols ,
1052- symbols_args_map ,
1053- None ,
1054- None ,
1055- )
1068+ if options .use_water_backend :
1069+ obj_path = _get_water_object_cache_path (options .kernel_hash )
1070+ if obj_path .exists ():
1071+ return WaveKernelExecutionEngine .from_object_file (
1072+ options , str (obj_path ), cached_kernel .asm
1073+ )
1074+ # Object file missing from cache, fall through
1075+ # to recompilation.
1076+ else :
1077+ if options .wave_runtime :
1078+ binary_path = get_binary_path ()
1079+
1080+ return cls (
1081+ options ,
1082+ cached_kernel .vmfb ,
1083+ cached_kernel .asm ,
1084+ binary_path ,
1085+ bound_scalar_symbols ,
1086+ symbols_args_map ,
1087+ None ,
1088+ None ,
1089+ )
10561090
10571091 # For the wave runtime, we need the hsaco binary. So we turn on
10581092 # dumping of binaries and store in wave runtime directory. If we
@@ -1176,12 +1210,25 @@ def get_binary_path():
11761210 _compile_asm_to_binary (asm , options )
11771211 elif options .use_water_backend :
11781212 module = water_lowering_pipeline (mb .module_op , options )
1179- return WaveKernelExecutionEngine (
1213+ engine = WaveKernelExecutionEngine (
11801214 options ,
11811215 module ,
11821216 asm ,
11831217 create_execution_engine = not options .compile_to_mlir ,
11841218 )
1219+ # Cache the compiled object file for future runs.
1220+ if (
1221+ is_cache_enabled ()
1222+ and cache_manager is not None
1223+ and options .kernel_hash
1224+ and not debug_arg_info
1225+ and not options .compile_to_mlir
1226+ ):
1227+ obj_path = _get_water_object_cache_path (options .kernel_hash )
1228+ obj_path .parent .mkdir (parents = True , exist_ok = True )
1229+ engine .dump_to_object_file (str (obj_path ))
1230+ cache_manager .store_kernel (None , asm , options )
1231+ return engine
11851232 elif not options .compile_to_mlir :
11861233 # LLVM flow: only compile to VMFB when not in MLIR-only mode
11871234 compiled_wave_vmfb = compile_to_vmfb (asm , options )
0 commit comments