Skip to content

Commit 3694ae0

Browse files
Hardcode84 and claude committed
Add disk caching for Water backend via object files
After JIT compilation, dump the host object file (which embeds the GPU binary) to the cache directory. On cache hit, load the .o directly into the execution engine, skipping MLIR parsing, LLVM IR translation, and host code compilation entirely.

- Add ExecutionEngine::loadFromObjectFile (C++ + Python binding).
- Enable object cache by default so dump_to_object_file works.
- Include use_water_backend in cache hash to avoid collisions.
- Add testWaterBackendCache covering miss, store, hit, and correctness.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
1 parent bceb332 commit 3694ae0

File tree

7 files changed

+231
-34
lines changed

7 files changed

+231
-34
lines changed

tests/kernel/runtime/cache_test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
require_cdna_3_or_4,
5252
require_rdna4,
5353
require_e2e,
54+
require_water_and_ee,
5455
)
5556

5657
require_cache = pytest.mark.skipif(
@@ -868,3 +869,87 @@ def simple_copy(
868869
assert kernel2.gpu_binary_path.endswith(
869870
".hsaco"
870871
), "Expected .hsaco extension for cached kernel"
872+
873+
874+
@require_e2e
@require_cache
@require_water_and_ee
def testWaterBackendCache(tmp_path):
    """Exercise the Water backend object-file cache end to end.

    Compiles the same copy kernel twice against a cache rooted at
    ``tmp_path`` and checks, in order: the first compile is a miss that
    leaves a non-empty ``<hash>/<hash>.o`` in the cache directory, the
    second compile is a hit, and both resulting kernels copy ``a`` into
    their output buffer.
    """
    reset_cache_manager(tmp_path)

    M = tkl.sym.M
    N = tkl.sym.N
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

    wave_size = 64
    BLOCK_M = 1
    BLOCK_N = 256

    hardware = tkw.HardwareConstraint(
        threads_per_wave=wave_size,
        vector_shapes={M: BLOCK_M, N: BLOCK_N},
    )
    constraints: list[tkw.Constraint] = [
        hardware,
        tkw.WorkgroupConstraint(M, BLOCK_M, 1),
        tkw.WorkgroupConstraint(N, BLOCK_N, 0),
        tkw.WaveConstraint(M, BLOCK_M),
        tkw.WaveConstraint(N, BLOCK_N),
    ]

    @tkw.wave(constraints)
    def simple_copy(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        res = tkw.read(a)
        tkw.write(res, b)

    # Concrete problem size; reused for the tensors allocated below.
    rows, cols = 16, 256
    hyperparams = {
        ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
        M: rows,
        N: cols,
    }

    cache_manager = get_cache_manager()

    options = set_default_run_config(
        WaveCompileOptions(
            subs=copy.deepcopy(hyperparams),
            canonicalize=True,
            use_water_backend=True,
        )
    )

    # Nothing has been compiled yet, so the session cache must be empty.
    assert len(cache_manager.session_cache) == 0, "Expected empty cache at start."

    # First compilation: a miss that should store an object file.
    first_kernel = wave_compile(options, simple_copy)

    assert (
        cache_manager.cache_misses == 1 and cache_manager.cache_hits == 0
    ), "Expected first compilation to be a cache miss."
    assert len(cache_manager.session_cache) == 1, "Expected one entry in session cache."

    # The object file is expected at <cache dir>/<hash>/<hash>.o.
    kernel_hash = options.kernel_hash
    obj_path = tmp_path / kernel_hash / (kernel_hash + ".o")
    assert obj_path.exists(), f"Expected object file at {obj_path}."
    assert obj_path.stat().st_size > 0, "Object file should not be empty."

    a = device_randn(rows, cols, dtype=torch.float16)
    b = device_zeros(rows, cols, dtype=torch.float16)
    first_kernel(a, b)
    assert_close(a, b)

    # Second compilation: a hit that should load the stored object file.
    b2 = device_zeros(rows, cols, dtype=torch.float16)
    second_kernel = wave_compile(options, simple_copy)

    assert (
        cache_manager.cache_misses == 1 and cache_manager.cache_hits == 1
    ), "Expected second compilation to be a cache hit."

    second_kernel(a, b2)
    assert_close(a, b2)

wave_lang/kernel/wave/cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def get_hash(
214214
options.reorder_allocs,
215215
options.override_schedule,
216216
options.use_bound_check,
217+
options.use_water_backend,
217218
]
218219

219220
# Add kernel/helper function specific hashes.

wave_lang/kernel/wave/compile.py

Lines changed: 80 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@
115115
from ..compiler import host_codegen, kernel_codegen, builder, dispatch_codegen
116116
from ..compiler.wave_codegen import WaveEmitter
117117
from .compile_options import WaveCompileOptions
118+
from pathlib import Path
119+
118120
from .cache import (
121+
get_cache_base_dir,
119122
get_cache_manager,
120123
get_temp_binary_dir,
121124
is_cache_enabled,
@@ -386,24 +389,23 @@ def __init__(
386389
self._module_handle = None
387390
self._host_func_ptr = None
388391

389-
# Serialize MLIR module to text if needed
390-
# TODO: investigate why bytecode deserialization is not working
392+
# Serialize MLIR module to text if needed.
393+
# TODO: investigate why bytecode deserialization is not working.
391394
if isinstance(module, (bytes, str)):
392-
# Assume it's already MLIR text
393395
optimized_mlir = module.decode() if isinstance(module, bytes) else module
394396
else:
395-
# Serialize the MLIR module to text
396397
optimized_mlir = str(module)
397398

398-
# Get the execution engine instance and load the module
399399
from wave_lang.kernel.wave.execution_engine import get_execution_engine
400400

401401
if not create_execution_engine:
402402
return
403403
self._engine = get_execution_engine()
404404
self._module_handle = self._engine.load_module_from_text(optimized_mlir)
405+
self._bind_host_func()
405406

406-
# Look up the host wrapper function
407+
def _bind_host_func(self):
408+
"""Look up the host wrapper function and create a ctypes callable."""
407409
func_name = self.options.func_name
408410
try:
409411
self._host_func_ptr = self._engine.lookup(self._module_handle, func_name)
@@ -413,32 +415,49 @@ def __init__(
413415
f"Make sure the module was compiled with emit_host_func. Error: {e}"
414416
)
415417

416-
# Create ctypes function type
417-
# The host wrapper signature is: void func(void* stream, PyObject* arg0, PyObject* arg1, ...)
418-
418+
# The host wrapper signature is:
419+
# void func(void* stream, PyObject* arg0, PyObject* arg1, ...).
419420
num_kernel_args = len(self.options.kernel_usages)
420-
arg_types = [ctypes.c_void_p] + [
421-
py_object
422-
] * num_kernel_args # +1 for stream pointer
421+
arg_types = [ctypes.c_void_p] + [py_object] * num_kernel_args
423422
func_type = ctypes.CFUNCTYPE(None, *arg_types)
424423
self._cfunc = func_type(self._host_func_ptr)
425424

425+
def dump_to_object_file(self, path: str):
    """Write the engine's compiled host object code to ``path``.

    Per the commit description the host object embeds the GPU binary.
    The instance must hold an execution engine; one created without an
    engine trips the assertion below.
    """
    engine = self._engine
    assert engine is not None, "no execution engine to dump from"
    engine.dump_to_object_file(path)
429+
430+
@classmethod
def from_object_file(
    cls,
    options: WaveCompileOptions,
    object_file_path: str,
    mlir_asm: str = "",
) -> "WaveKernelExecutionEngine":
    """Build a kernel wrapper from a cached object file.

    Alternate constructor used on a cache hit: rather than loading the
    module from MLIR text, the object file at ``object_file_path`` is
    handed straight to the shared execution engine and the host wrapper
    function is bound from the resulting module handle.

    Args:
        options: Compile options; stored on the instance and read by
            ``_bind_host_func``.
        object_file_path: Path to the cached ``.o`` file.
        mlir_asm: MLIR text kept on the instance as ``asm``.

    Returns:
        A ready-to-invoke ``WaveKernelExecutionEngine``.
    """
    from wave_lang.kernel.wave.execution_engine import get_execution_engine

    # Allocate without running __init__, so only the fields below are set.
    kernel = cls.__new__(cls)
    kernel.options = options
    kernel.asm = mlir_asm

    jit_engine = get_execution_engine()
    kernel._engine = jit_engine
    kernel._module_handle = jit_engine.load_from_object_file(object_file_path)
    kernel._bind_host_func()
    return kernel
449+
426450
def __call__(self, *args):
427451
return self.invoke(*args)
428452

429453
def invoke(self, *args) -> None:
430-
"""
431-
Invokes the wave kernel with the given arguments using the ExecutionEngine.
432-
"""
454+
"""Invoke the wave kernel with the given arguments using the ExecutionEngine."""
433455
assert (
434456
self._engine is not None
435-
), "Cannot invoke kernel without creating an execution engine. Revise the constructor call."
457+
), "Cannot invoke kernel without creating an execution engine."
436458

437-
# Get the current stream
438459
stream_ptr = torch.cuda.current_stream().cuda_stream
439-
440-
# Call the JIT-compiled host wrapper function
441-
# Signature: void func(void* stream, PyObject* arg0, PyObject* arg1, ...)
460+
# Signature: void func(void* stream, PyObject* arg0, PyObject* arg1, ...).
442461
self._cfunc(stream_ptr, *(py_object(arg) for arg in args))
443462

444463

@@ -1013,6 +1032,11 @@ def get_binary_path():
10131032
else:
10141033
return glob.glob(str(get_temp_binary_dir() / "*.hsaco"))[0]
10151034

1035+
def _get_water_object_cache_path(kernel_hash: str) -> Path:
    """Return ``<cache root>/<kernel_hash>/<kernel_hash>.o``.

    The cache root is the cache manager's ``base_dir`` when a manager is
    available, and ``get_cache_base_dir()`` otherwise.
    """
    if cache_manager:
        cache_root = cache_manager.base_dir
    else:
        cache_root = get_cache_base_dir()
    object_name = kernel_hash + ".o"
    return cache_root / kernel_hash / object_name
1039+
10161040
# Create an indexing context and populate substitutions.
10171041
with IndexingContext() as idxc:
10181042
idxc.set_subs(options.subs)
@@ -1037,22 +1061,32 @@ def get_binary_path():
10371061
if cached_kernel:
10381062
options.kernel_usages = cached_kernel.kernel_sig
10391063
options.kernel_launch_info = cached_kernel.kernel_launch_info
1040-
if options.wave_runtime:
1041-
binary_path = get_binary_path()
10421064

10431065
if options.print_mlir:
10441066
print(cached_kernel.asm)
10451067

1046-
return cls(
1047-
options,
1048-
cached_kernel.vmfb,
1049-
cached_kernel.asm,
1050-
binary_path,
1051-
bound_scalar_symbols,
1052-
symbols_args_map,
1053-
None,
1054-
None,
1055-
)
1068+
if options.use_water_backend:
1069+
obj_path = _get_water_object_cache_path(options.kernel_hash)
1070+
if obj_path.exists():
1071+
return WaveKernelExecutionEngine.from_object_file(
1072+
options, str(obj_path), cached_kernel.asm
1073+
)
1074+
# Object file missing from cache, fall through
1075+
# to recompilation.
1076+
else:
1077+
if options.wave_runtime:
1078+
binary_path = get_binary_path()
1079+
1080+
return cls(
1081+
options,
1082+
cached_kernel.vmfb,
1083+
cached_kernel.asm,
1084+
binary_path,
1085+
bound_scalar_symbols,
1086+
symbols_args_map,
1087+
None,
1088+
None,
1089+
)
10561090

10571091
# For the wave runtime, we need the hsaco binary. So we turn on
10581092
# dumping of binaries and store in wave runtime directory. If we
@@ -1176,12 +1210,25 @@ def get_binary_path():
11761210
_compile_asm_to_binary(asm, options)
11771211
elif options.use_water_backend:
11781212
module = water_lowering_pipeline(mb.module_op, options)
1179-
return WaveKernelExecutionEngine(
1213+
engine = WaveKernelExecutionEngine(
11801214
options,
11811215
module,
11821216
asm,
11831217
create_execution_engine=not options.compile_to_mlir,
11841218
)
1219+
# Cache the compiled object file for future runs.
1220+
if (
1221+
is_cache_enabled()
1222+
and cache_manager is not None
1223+
and options.kernel_hash
1224+
and not debug_arg_info
1225+
and not options.compile_to_mlir
1226+
):
1227+
obj_path = _get_water_object_cache_path(options.kernel_hash)
1228+
obj_path.parent.mkdir(parents=True, exist_ok=True)
1229+
engine.dump_to_object_file(str(obj_path))
1230+
cache_manager.store_kernel(None, asm, options)
1231+
return engine
11851232
elif not options.compile_to_mlir:
11861233
# LLVM flow: only compile to VMFB when not in MLIR-only mode
11871234
compiled_wave_vmfb = compile_to_vmfb(asm, options)

wave_lang/kernel/wave/execution_engine/bindings.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,24 @@ NB_MODULE(wave_execution_engine, m) {
191191
192192
Raises:
193193
RuntimeError: If function lookup fails)")
194+
.def(
195+
"load_from_object_file",
196+
[](wave::ExecutionEngine &self, const std::string &filename) {
197+
return reinterpret_cast<uintptr_t>(
198+
unwrapExpected(self.loadFromObjectFile(filename),
199+
"Failed to load object file"));
200+
},
201+
nb::arg("filename"),
202+
R"(Load a pre-compiled object file into the execution engine.
203+
204+
Args:
205+
filename: Path to the object file
206+
207+
Returns:
208+
Module handle as integer
209+
210+
Raises:
211+
RuntimeError: If loading fails)")
194212
.def(
195213
"dump_to_object_file",
196214
[](wave::ExecutionEngine &self, const std::string &filename) {

wave_lang/kernel/wave/execution_engine/execution_engine.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,49 @@ wave::ExecutionEngine::lookup(wave::ExecutionEngine::ModuleHandle handle,
368368
return makeStringError("looked up function is null");
369369
}
370370

371+
/// Load a pre-compiled object file from `filename` into a fresh JITDylib.
///
/// The object is read from disk, added to a newly created, uniquely named
/// JITDylib that can resolve symbols from the current process (and from
/// `symbolMap`, when set), and its initializers are run.
///
/// Returns the JITDylib as a module handle on success. Returns an error if the
/// file cannot be read, the JITDylib cannot be created, or the object cannot be
/// added or initialized. The input is an on-disk cache entry and may be
/// truncated or corrupt, so those failures are reported to the caller instead
/// of aborting the process; the partially set up JITDylib is removed first so
/// the engine is left without a half-loaded module.
llvm::Expected<wave::ExecutionEngine::ModuleHandle>
wave::ExecutionEngine::loadFromObjectFile(llvm::StringRef filename) {
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
      llvm::MemoryBuffer::getFile(filename);
  if (!fileOrErr)
    return makeStringError("could not open object file '" + filename +
                           "': " + fileOrErr.getError().message());

  // Create a unique JITDylib for this object.
  llvm::orc::JITDylib *dylib = nullptr;
  while (true) {
    std::string uniqueName =
        (llvm::Twine("module") + llvm::Twine(uniqueNameCounter++)).str();
    if (jit->getJITDylibByName(uniqueName))
      continue;

    llvm::Expected<llvm::orc::JITDylib &> res =
        jit->createJITDylib(std::move(uniqueName));
    if (!res)
      return res.takeError();

    dylib = &res.get();
    break;
  }
  assert(dylib && "failed to create JITDylib");

  const llvm::DataLayout &dataLayout = jit->getDataLayout();
  dylib->addGenerator(
      cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
          dataLayout.getGlobalPrefix())));

  if (symbolMap)
    cantFail(
        dylib->define(absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
            dylib->getExecutionSession(), dataLayout)))));

  // On failure, drop the JITDylib created above and report both the original
  // error and any error from the removal itself.
  auto removeDylibAndFail = [&](llvm::Error err) -> llvm::Error {
    return llvm::joinErrors(
        std::move(err), jit->getExecutionSession().removeJITDylib(*dylib));
  };

  if (llvm::Error err =
          jit->addObjectFile(*dylib, std::move(fileOrErr.get())))
    return removeDylibAndFail(std::move(err));
  if (llvm::Error err = jit->initialize(*dylib))
    return removeDylibAndFail(std::move(err));
  return static_cast<ModuleHandle>(dylib);
}
413+
371414
llvm::Error wave::ExecutionEngine::dumpToObjectFile(llvm::StringRef filename) {
372415
if (cache == nullptr)
373416
return makeStringError("cannot dump ExecutionEngine object code to file: "

wave_lang/kernel/wave/execution_engine/execution_engine.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ class ExecutionEngine {
9696
llvm::Expected<void *> lookup(ModuleHandle handle,
9797
llvm::StringRef name) const;
9898

99+
/// Load a pre-compiled object file into the execution engine.
100+
llvm::Expected<ModuleHandle> loadFromObjectFile(llvm::StringRef filename);
101+
99102
/// Dump object code to output file `filename`.
100103
llvm::Error dumpToObjectFile(llvm::StringRef filename);
101104

wave_lang/kernel/wave/execution_engine/execution_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _create_options_from_env() -> "ExecutionEngineOptions":
158158
def _env_enabled(var: str, default: str = "0") -> bool:
159159
return bool(int(os.environ.get(var, default)))
160160

161-
options.enable_object_cache = _env_enabled("WAVE_ENABLE_OBJECT_CACHE")
161+
options.enable_object_cache = _env_enabled("WAVE_ENABLE_OBJECT_CACHE", "1")
162162
options.enable_gdb_notification_listener = _env_enabled("WAVE_ENABLE_GDB_LISTENER")
163163
options.enable_perf_notification_listener = _env_enabled(
164164
"WAVE_ENABLE_PERF_LISTENER"

0 commit comments

Comments
 (0)