Skip to content

Commit 9660f5d

Browse files
Hardcode84 and claude committed
Add disk caching for Water backend via object files
After JIT compilation, dump the host object file (which embeds the GPU binary) to the cache directory. On cache hit, load the .o directly into the execution engine, skipping MLIR parsing, LLVM IR translation, and host code compilation entirely.

- Add ExecutionEngine::loadFromObjectFile (C++ + Python binding).
- Enable object cache by default so dump_to_object_file works.
- Include use_water_backend in cache hash to avoid collisions.
- Add testWaterBackendCache covering miss, store, hit, and correctness.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
1 parent bceb332 commit 9660f5d

File tree

7 files changed

+223
-16
lines changed

7 files changed

+223
-16
lines changed

tests/kernel/runtime/cache_test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
require_cdna_3_or_4,
5252
require_rdna4,
5353
require_e2e,
54+
require_water_and_ee,
5455
)
5556

5657
require_cache = pytest.mark.skipif(
@@ -868,3 +869,87 @@ def simple_copy(
868869
assert kernel2.gpu_binary_path.endswith(
869870
".hsaco"
870871
), "Expected .hsaco extension for cached kernel"
872+
873+
874+
@require_e2e
@require_cache
@require_water_and_ee
def testWaterBackendCache(tmp_path):
    """Exercise the Water backend object-file cache end to end.

    Compiles the same copy kernel twice against a fresh cache rooted at
    ``tmp_path`` and checks, in order: the first compile is a miss that
    writes a non-empty ``<hash>/<hash>.o`` file, the second compile is a
    hit, and both resulting kernels copy their input correctly.
    """
    reset_cache_manager(tmp_path)

    # Symbolic dimensions and address space used by the kernel signature.
    M = tkl.sym.M
    N = tkl.sym.N
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

    wave_size = 64
    BLOCK_M = 1
    BLOCK_N = 256

    hardware = tkw.HardwareConstraint(
        threads_per_wave=wave_size,
        vector_shapes={M: BLOCK_M, N: BLOCK_N},
    )
    constraints: list[tkw.Constraint] = [
        hardware,
        tkw.WorkgroupConstraint(M, BLOCK_M, 1),
        tkw.WorkgroupConstraint(N, BLOCK_N, 0),
        tkw.WaveConstraint(M, BLOCK_M),
        tkw.WaveConstraint(N, BLOCK_N),
    ]

    @tkw.wave(constraints)
    def simple_copy(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        res = tkw.read(a)
        tkw.write(res, b)

    hyperparams = {
        ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
        M: 16,
        N: 256,
    }

    manager = get_cache_manager()

    compile_options = set_default_run_config(
        WaveCompileOptions(
            subs=copy.deepcopy(hyperparams),
            canonicalize=True,
            use_water_backend=True,
        )
    )

    # Nothing has been compiled yet, so the session cache must be empty.
    assert len(manager.session_cache) == 0, "Expected empty cache at start."

    # First compile: a miss that should populate the cache with an object file.
    first_kernel = wave_compile(compile_options, simple_copy)

    assert (
        manager.cache_misses == 1 and manager.cache_hits == 0
    ), "Expected first compilation to be a cache miss."
    assert len(manager.session_cache) == 1, "Expected one entry in session cache."

    # The object file is expected at <cache dir>/<hash>/<hash>.o.
    kernel_hash = compile_options.kernel_hash
    object_file = tmp_path / kernel_hash / f"{kernel_hash}.o"
    assert object_file.exists(), f"Expected object file at {object_file}."
    assert object_file.stat().st_size > 0, "Object file should not be empty."

    source = device_randn(16, 256, dtype=torch.float16)
    first_dest = device_zeros(16, 256, dtype=torch.float16)
    first_kernel(source, first_dest)
    assert_close(source, first_dest)

    # Second compile: a hit that should be served from the stored object file.
    second_dest = device_zeros(16, 256, dtype=torch.float16)
    second_kernel = wave_compile(compile_options, simple_copy)

    assert (
        manager.cache_misses == 1 and manager.cache_hits == 1
    ), "Expected second compilation to be a cache hit."

    second_kernel(source, second_dest)
    assert_close(source, second_dest)

wave_lang/kernel/wave/cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def get_hash(
214214
options.reorder_allocs,
215215
options.override_schedule,
216216
options.use_bound_check,
217+
options.use_water_backend,
217218
]
218219

219220
# Add kernel/helper function specific hashes.

wave_lang/kernel/wave/compile.py

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@
115115
from ..compiler import host_codegen, kernel_codegen, builder, dispatch_codegen
116116
from ..compiler.wave_codegen import WaveEmitter
117117
from .compile_options import WaveCompileOptions
118+
from pathlib import Path
119+
118120
from .cache import (
121+
get_cache_base_dir,
119122
get_cache_manager,
120123
get_temp_binary_dir,
121124
is_cache_enabled,
@@ -402,8 +405,10 @@ def __init__(
402405
return
403406
self._engine = get_execution_engine()
404407
self._module_handle = self._engine.load_module_from_text(optimized_mlir)
408+
self._bind_host_func()
405409

406-
# Look up the host wrapper function
410+
def _bind_host_func(self):
411+
"""Look up the host wrapper function and create a ctypes callable."""
407412
func_name = self.options.func_name
408413
try:
409414
self._host_func_ptr = self._engine.lookup(self._module_handle, func_name)
@@ -415,14 +420,38 @@ def __init__(
415420

416421
# Create ctypes function type
417422
# The host wrapper signature is: void func(void* stream, PyObject* arg0, PyObject* arg1, ...)
418-
419423
num_kernel_args = len(self.options.kernel_usages)
420424
arg_types = [ctypes.c_void_p] + [
421425
py_object
422426
] * num_kernel_args # +1 for stream pointer
423427
func_type = ctypes.CFUNCTYPE(None, *arg_types)
424428
self._cfunc = func_type(self._host_func_ptr)
425429

430+
def dump_to_object_file(self, path: str):
    """Dump the compiled host object file (with embedded GPU binary) to disk.

    Args:
        path: Destination file path; forwarded unchanged to the execution
            engine's ``dump_to_object_file``.

    Raises:
        RuntimeError: If this kernel has no execution engine to dump from.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would turn a missing engine into an opaque
    # AttributeError. ``getattr`` also covers an instance on which
    # ``_engine`` was never assigned.
    engine = getattr(self, "_engine", None)
    if engine is None:
        raise RuntimeError("no execution engine to dump from")
    engine.dump_to_object_file(path)
434+
435+
@classmethod
def from_object_file(
    cls,
    options: WaveCompileOptions,
    object_file_path: str,
    mlir_asm: str = "",
) -> "WaveKernelExecutionEngine":
    """Alternate constructor: build a kernel from a cached object file.

    ``__init__`` is bypassed via ``cls.__new__``, so only the attributes
    assigned here (plus whatever ``_bind_host_func`` sets) exist on the
    returned instance.

    Args:
        options: Compile options for the kernel.
        object_file_path: Path of the object file to load into the engine.
        mlir_asm: MLIR text stored on the instance as ``asm``.

    Returns:
        A kernel whose module handle comes from the loaded object file.
    """
    from wave_lang.kernel.wave.execution_engine import get_execution_engine

    kernel = cls.__new__(cls)
    kernel.options = options
    kernel.asm = mlir_asm

    engine = get_execution_engine()
    kernel._engine = engine
    kernel._module_handle = engine.load_from_object_file(object_file_path)

    kernel._bind_host_func()
    return kernel
454+
426455
def __call__(self, *args):
427456
return self.invoke(*args)
428457

@@ -1013,6 +1042,11 @@ def get_binary_path():
10131042
else:
10141043
return glob.glob(str(get_temp_binary_dir() / "*.hsaco"))[0]
10151044

1045+
def _get_water_object_cache_path(kernel_hash: str) -> Path:
    """Locate the cached Water object file for ``kernel_hash``.

    The result is ``<base>/<kernel_hash>/<kernel_hash>.o``, where ``<base>``
    is ``cache_manager.base_dir`` when a cache manager is available and the
    value of ``get_cache_base_dir()`` otherwise.
    """
    if cache_manager:
        cache_root = cache_manager.base_dir
    else:
        cache_root = get_cache_base_dir()
    return cache_root / kernel_hash / f"{kernel_hash}.o"
1049+
10161050
# Create an indexing context and populate substitutions.
10171051
with IndexingContext() as idxc:
10181052
idxc.set_subs(options.subs)
@@ -1037,22 +1071,32 @@ def get_binary_path():
10371071
if cached_kernel:
10381072
options.kernel_usages = cached_kernel.kernel_sig
10391073
options.kernel_launch_info = cached_kernel.kernel_launch_info
1040-
if options.wave_runtime:
1041-
binary_path = get_binary_path()
10421074

10431075
if options.print_mlir:
10441076
print(cached_kernel.asm)
10451077

1046-
return cls(
1047-
options,
1048-
cached_kernel.vmfb,
1049-
cached_kernel.asm,
1050-
binary_path,
1051-
bound_scalar_symbols,
1052-
symbols_args_map,
1053-
None,
1054-
None,
1055-
)
1078+
if options.use_water_backend:
1079+
obj_path = _get_water_object_cache_path(options.kernel_hash)
1080+
if obj_path.exists():
1081+
return WaveKernelExecutionEngine.from_object_file(
1082+
options, str(obj_path), cached_kernel.asm
1083+
)
1084+
# Object file missing from cache, fall through
1085+
# to recompilation.
1086+
else:
1087+
if options.wave_runtime:
1088+
binary_path = get_binary_path()
1089+
1090+
return cls(
1091+
options,
1092+
cached_kernel.vmfb,
1093+
cached_kernel.asm,
1094+
binary_path,
1095+
bound_scalar_symbols,
1096+
symbols_args_map,
1097+
None,
1098+
None,
1099+
)
10561100

10571101
# For the wave runtime, we need the hsaco binary. So we turn on
10581102
# dumping of binaries and store in wave runtime directory. If we
@@ -1176,12 +1220,25 @@ def get_binary_path():
11761220
_compile_asm_to_binary(asm, options)
11771221
elif options.use_water_backend:
11781222
module = water_lowering_pipeline(mb.module_op, options)
1179-
return WaveKernelExecutionEngine(
1223+
engine = WaveKernelExecutionEngine(
11801224
options,
11811225
module,
11821226
asm,
11831227
create_execution_engine=not options.compile_to_mlir,
11841228
)
1229+
# Cache the compiled object file for future runs.
1230+
if (
1231+
is_cache_enabled()
1232+
and cache_manager is not None
1233+
and options.kernel_hash
1234+
and not debug_arg_info
1235+
and not options.compile_to_mlir
1236+
):
1237+
obj_path = _get_water_object_cache_path(options.kernel_hash)
1238+
obj_path.parent.mkdir(parents=True, exist_ok=True)
1239+
engine.dump_to_object_file(str(obj_path))
1240+
cache_manager.store_kernel(None, asm, options)
1241+
return engine
11851242
elif not options.compile_to_mlir:
11861243
# LLVM flow: only compile to VMFB when not in MLIR-only mode
11871244
compiled_wave_vmfb = compile_to_vmfb(asm, options)

wave_lang/kernel/wave/execution_engine/bindings.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,24 @@ NB_MODULE(wave_execution_engine, m) {
191191
192192
Raises:
193193
RuntimeError: If function lookup fails)")
194+
.def(
195+
"load_from_object_file",
196+
[](wave::ExecutionEngine &self, const std::string &filename) {
197+
return reinterpret_cast<uintptr_t>(
198+
unwrapExpected(self.loadFromObjectFile(filename),
199+
"Failed to load object file"));
200+
},
201+
nb::arg("filename"),
202+
R"(Load a pre-compiled object file into the execution engine.
203+
204+
Args:
205+
filename: Path to the object file
206+
207+
Returns:
208+
Module handle as integer
209+
210+
Raises:
211+
RuntimeError: If loading fails)")
194212
.def(
195213
"dump_to_object_file",
196214
[](wave::ExecutionEngine &self, const std::string &filename) {

wave_lang/kernel/wave/execution_engine/execution_engine.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,49 @@ wave::ExecutionEngine::lookup(wave::ExecutionEngine::ModuleHandle handle,
368368
return makeStringError("looked up function is null");
369369
}
370370

371+
/// Load a pre-compiled object file into a freshly created JITDylib and return
/// that dylib as a module handle.
///
/// Every failure is reported through the returned `llvm::Expected` instead of
/// aborting the process: the file cannot be read, the JITDylib cannot be
/// created, or any later setup step (process symbol generator, runtime symbol
/// definitions, adding the object, running initializers) reports an error.
/// When a step fails after the JITDylib was created, the dylib is removed from
/// the execution session again so no half-populated module is left behind.
llvm::Expected<wave::ExecutionEngine::ModuleHandle>
wave::ExecutionEngine::loadFromObjectFile(llvm::StringRef filename) {
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
      llvm::MemoryBuffer::getFile(filename);
  if (!fileOrErr)
    return makeStringError("could not open object file '" + filename +
                           "': " + fileOrErr.getError().message());

  // Create a unique JITDylib for this object; skip names already in use.
  llvm::orc::JITDylib *dylib = nullptr;
  while (!dylib) {
    std::string uniqueName =
        (llvm::Twine("module") + llvm::Twine(uniqueNameCounter++)).str();
    if (jit->getJITDylibByName(uniqueName))
      continue;

    llvm::Expected<llvm::orc::JITDylib &> res =
        jit->createJITDylib(std::move(uniqueName));
    if (!res)
      return res.takeError();

    dylib = &res.get();
  }

  // On failure, drop the dylib created above and report both the original
  // error and any error produced by the removal itself.
  auto discardDylib = [&](llvm::Error err) -> llvm::Error {
    return llvm::joinErrors(
        std::move(err), dylib->getExecutionSession().removeJITDylib(*dylib));
  };

  // Let the object resolve symbols exported by the current process.
  const llvm::DataLayout &dataLayout = jit->getDataLayout();
  auto generator =
      llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
          dataLayout.getGlobalPrefix());
  if (!generator)
    return discardDylib(generator.takeError());
  dylib->addGenerator(std::move(*generator));

  // Register the engine's extra runtime symbols, if a symbol map was given.
  if (symbolMap) {
    if (llvm::Error err =
            dylib->define(absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
                dylib->getExecutionSession(), dataLayout)))))
      return discardDylib(std::move(err));
  }

  // The object file typically comes from an on-disk cache, so a bad file must
  // surface as an error the caller can handle rather than a process abort.
  if (llvm::Error err =
          jit->addObjectFile(*dylib, std::move(fileOrErr.get())))
    return discardDylib(std::move(err));
  if (llvm::Error err = jit->initialize(*dylib))
    return discardDylib(std::move(err));

  return static_cast<ModuleHandle>(dylib);
}
413+
371414
llvm::Error wave::ExecutionEngine::dumpToObjectFile(llvm::StringRef filename) {
372415
if (cache == nullptr)
373416
return makeStringError("cannot dump ExecutionEngine object code to file: "

wave_lang/kernel/wave/execution_engine/execution_engine.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ class ExecutionEngine {
9696
llvm::Expected<void *> lookup(ModuleHandle handle,
9797
llvm::StringRef name) const;
9898

99+
/// Load a pre-compiled object file into the execution engine.
100+
llvm::Expected<ModuleHandle> loadFromObjectFile(llvm::StringRef filename);
101+
99102
/// Dump object code to output file `filename`.
100103
llvm::Error dumpToObjectFile(llvm::StringRef filename);
101104

wave_lang/kernel/wave/execution_engine/execution_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _create_options_from_env() -> "ExecutionEngineOptions":
158158
def _env_enabled(var: str, default: str = "0") -> bool:
159159
return bool(int(os.environ.get(var, default)))
160160

161-
options.enable_object_cache = _env_enabled("WAVE_ENABLE_OBJECT_CACHE")
161+
options.enable_object_cache = _env_enabled("WAVE_ENABLE_OBJECT_CACHE", "1")
162162
options.enable_gdb_notification_listener = _env_enabled("WAVE_ENABLE_GDB_LISTENER")
163163
options.enable_perf_notification_listener = _env_enabled(
164164
"WAVE_ENABLE_PERF_LISTENER"

0 commit comments

Comments
 (0)