Skip to content

Commit 1f97430

Browse files
Hardcode84 and claude committed
Add disk caching for Water backend via object files
After JIT compilation, dump the host object file (which embeds the GPU binary) to the cache directory. On a cache hit, load the .o directly into the execution engine, skipping MLIR parsing, LLVM IR translation, and host code compilation entirely.

- Add ExecutionEngine::loadFromObjectFile (C++ + Python binding).
- Enable object cache by default so dump_to_object_file works.
- Include use_water_backend in cache hash to avoid collisions.
- Add testWaterBackendCache covering miss, store, hit, and correctness.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
1 parent 6560184 commit 1f97430

File tree

7 files changed

+223
-16
lines changed

7 files changed

+223
-16
lines changed

tests/kernel/runtime/cache_test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
require_cdna_3_or_4,
5252
require_rdna4,
5353
require_e2e,
54+
require_water_and_ee,
5455
)
5556

5657
require_cache = pytest.mark.skipif(
@@ -872,3 +873,87 @@ def simple_copy(
872873
assert kernel2.gpu_binary_path.endswith(
873874
".hsaco"
874875
), "Expected .hsaco extension for cached kernel"
876+
877+
878+
@require_e2e
@require_cache
@require_water_and_ee
def testWaterBackendCache(tmp_path):
    """Exercise the Water backend object-file cache end to end.

    The same copy kernel is compiled twice against a cache rooted at
    ``tmp_path``. The first compile must register as a miss and leave a
    non-empty ``<hash>/<hash>.o`` file behind; the second must register as a
    hit. Both resulting kernels are run and checked for a correct copy.
    """
    reset_cache_manager(tmp_path)

    M = tkl.sym.M
    N = tkl.sym.N
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

    wave_size = 64
    BLOCK_M = 1
    BLOCK_N = 256

    hardware = tkw.HardwareConstraint(
        threads_per_wave=wave_size,
        vector_shapes={M: BLOCK_M, N: BLOCK_N},
    )
    constraints: list[tkw.Constraint] = [
        hardware,
        tkw.WorkgroupConstraint(M, BLOCK_M, 1),
        tkw.WorkgroupConstraint(N, BLOCK_N, 0),
        tkw.WaveConstraint(M, BLOCK_M),
        tkw.WaveConstraint(N, BLOCK_N),
    ]

    @tkw.wave(constraints)
    def simple_copy(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        res = tkw.read(a)
        tkw.write(res, b)

    hyperparams = {ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, M: 16, N: 256}

    manager = get_cache_manager()

    compile_opts = set_default_run_config(
        WaveCompileOptions(
            subs=copy.deepcopy(hyperparams),
            canonicalize=True,
            use_water_backend=True,
        )
    )

    # Nothing has been compiled yet, so the session cache starts out empty.
    assert len(manager.session_cache) == 0, "Expected empty cache at start."

    # First compile: a miss that should produce and store an object file.
    first_kernel = wave_compile(compile_opts, simple_copy)

    assert (manager.cache_misses, manager.cache_hits) == (
        1,
        0,
    ), "Expected first compilation to be a cache miss."
    assert len(manager.session_cache) == 1, "Expected one entry in session cache."

    # The object file lives at <cache root>/<hash>/<hash>.o.
    digest = compile_opts.kernel_hash
    cached_object = tmp_path / digest / f"{digest}.o"
    assert cached_object.exists(), f"Expected object file at {cached_object}."
    assert cached_object.stat().st_size > 0, "Object file should not be empty."

    src = device_randn(16, 256, dtype=torch.float16)
    dst = device_zeros(16, 256, dtype=torch.float16)
    first_kernel(src, dst)
    assert_close(src, dst)

    # Second compile: a hit that should be served from the object file.
    dst_cached = device_zeros(16, 256, dtype=torch.float16)
    second_kernel = wave_compile(compile_opts, simple_copy)

    assert (manager.cache_misses, manager.cache_hits) == (
        1,
        1,
    ), "Expected second compilation to be a cache hit."

    second_kernel(src, dst_cached)
    assert_close(src, dst_cached)

wave_lang/kernel/wave/cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def get_hash(
214214
options.reorder_allocs,
215215
options.override_schedule,
216216
options.use_bound_check,
217+
options.use_water_backend,
217218
]
218219

219220
# Add kernel/helper function specific hashes.

wave_lang/kernel/wave/compile.py

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@
123123
from ..compiler import host_codegen, kernel_codegen, builder, dispatch_codegen
124124
from ..compiler.wave_codegen import WaveEmitter
125125
from .compile_options import WaveCompileOptions
126+
from pathlib import Path
127+
126128
from .cache import (
129+
get_cache_base_dir,
127130
get_cache_manager,
128131
get_temp_binary_dir,
129132
is_cache_enabled,
@@ -414,8 +417,10 @@ def __init__(
414417
return
415418
self._engine = get_execution_engine()
416419
self._module_handle = self._engine.load_module_from_text(optimized_mlir)
420+
self._bind_host_func()
417421

418-
# Look up the host wrapper function
422+
def _bind_host_func(self):
423+
"""Look up the host wrapper function and create a ctypes callable."""
419424
func_name = self.options.func_name
420425
try:
421426
self._host_func_ptr = self._engine.lookup(self._module_handle, func_name)
@@ -427,14 +432,38 @@ def __init__(
427432

428433
# Create ctypes function type
429434
# The host wrapper signature is: void func(void* stream, PyObject* arg0, PyObject* arg1, ...)
430-
431435
num_kernel_args = len(self.options.kernel_usages)
432436
arg_types = [ctypes.c_void_p] + [
433437
py_object
434438
] * num_kernel_args # +1 for stream pointer
435439
func_type = ctypes.CFUNCTYPE(None, *arg_types)
436440
self._cfunc = func_type(self._host_func_ptr)
437441

442+
def dump_to_object_file(self, path: str):
    """Dump the compiled host object file (with embedded GPU binary) to disk.

    Args:
        path: Destination file path, forwarded unchanged to the execution
            engine's ``dump_to_object_file``.

    Raises:
        RuntimeError: If this kernel has no execution engine attached.
    """
    # Raise explicitly instead of asserting: ``assert`` is stripped under
    # ``python -O``, which would turn a missing engine into an opaque
    # AttributeError. ``getattr`` also covers instances on which ``_engine``
    # was never assigned.
    # NOTE(review): presumably this happens when the kernel was built without
    # an execution engine (MLIR-only compilation) -- confirm against __init__.
    engine = getattr(self, "_engine", None)
    if engine is None:
        raise RuntimeError("no execution engine to dump from")
    engine.dump_to_object_file(path)
446+
447+
@classmethod
def from_object_file(
    cls,
    options: WaveCompileOptions,
    object_file_path: str,
    mlir_asm: str = "",
) -> "WaveKernelExecutionEngine":
    """Build a kernel from a cached object file instead of compiling MLIR.

    Args:
        options: Compile options for the kernel; stored on the instance.
        object_file_path: Path of the object file handed to the execution
            engine's ``load_from_object_file``.
        mlir_asm: MLIR text stored as ``asm`` on the instance.

    Returns:
        A new instance whose host function has been bound.
    """
    from wave_lang.kernel.wave.execution_engine import get_execution_engine

    engine = get_execution_engine()

    # ``__init__`` is deliberately bypassed; only the attributes below are set.
    # NOTE(review): anything else ``__init__`` normally assigns is absent on
    # instances created here -- confirm nothing downstream relies on it.
    kernel = cls.__new__(cls)
    kernel.options = options
    kernel.asm = mlir_asm
    kernel._engine = engine
    kernel._module_handle = engine.load_from_object_file(object_file_path)
    kernel._bind_host_func()
    return kernel
466+
438467
def __call__(self, *args):
439468
return self.invoke(*args)
440469

@@ -1034,6 +1063,11 @@ def get_binary_path():
10341063
else:
10351064
return glob.glob(str(get_temp_binary_dir() / "*.hsaco"))[0]
10361065

1066+
def _get_water_object_cache_path(kernel_hash: str) -> Path:
    """Return ``<cache root>/<kernel_hash>/<kernel_hash>.o`` for a Water kernel.

    The root is the cache manager's ``base_dir`` when a manager is available,
    otherwise the result of ``get_cache_base_dir()``.
    """
    if cache_manager:
        root = cache_manager.base_dir
    else:
        root = get_cache_base_dir()
    return root / kernel_hash / f"{kernel_hash}.o"
1070+
10371071
# Create an indexing context and populate substitutions.
10381072
with IndexingContext() as idxc:
10391073
idxc.set_subs(options.subs)
@@ -1058,22 +1092,32 @@ def get_binary_path():
10581092
if cached_kernel:
10591093
options.kernel_usages = cached_kernel.kernel_sig
10601094
options.kernel_launch_info = cached_kernel.kernel_launch_info
1061-
if options.wave_runtime:
1062-
binary_path = get_binary_path()
10631095

10641096
if options.print_mlir:
10651097
print(cached_kernel.asm)
10661098

1067-
return cls(
1068-
options,
1069-
cached_kernel.vmfb,
1070-
cached_kernel.asm,
1071-
binary_path,
1072-
bound_scalar_symbols,
1073-
symbols_args_map,
1074-
None,
1075-
None,
1076-
)
1099+
if options.use_water_backend:
1100+
obj_path = _get_water_object_cache_path(options.kernel_hash)
1101+
if obj_path.exists():
1102+
return WaveKernelExecutionEngine.from_object_file(
1103+
options, str(obj_path), cached_kernel.asm
1104+
)
1105+
# Object file missing from cache, fall through
1106+
# to recompilation.
1107+
else:
1108+
if options.wave_runtime:
1109+
binary_path = get_binary_path()
1110+
1111+
return cls(
1112+
options,
1113+
cached_kernel.vmfb,
1114+
cached_kernel.asm,
1115+
binary_path,
1116+
bound_scalar_symbols,
1117+
symbols_args_map,
1118+
None,
1119+
None,
1120+
)
10771121

10781122
# For the wave runtime, we need the hsaco binary. So we turn on
10791123
# dumping of binaries and store in wave runtime directory. If we
@@ -1210,12 +1254,25 @@ def get_binary_path():
12101254
_compile_asm_to_binary(asm, options)
12111255
elif options.use_water_backend:
12121256
module = water_lowering_pipeline(mb.module_op, options)
1213-
return WaveKernelExecutionEngine(
1257+
engine = WaveKernelExecutionEngine(
12141258
options,
12151259
module,
12161260
asm,
12171261
create_execution_engine=not options.compile_to_mlir,
12181262
)
1263+
# Cache the compiled object file for future runs.
1264+
if (
1265+
is_cache_enabled()
1266+
and cache_manager is not None
1267+
and options.kernel_hash
1268+
and not debug_arg_info
1269+
and not options.compile_to_mlir
1270+
):
1271+
obj_path = _get_water_object_cache_path(options.kernel_hash)
1272+
obj_path.parent.mkdir(parents=True, exist_ok=True)
1273+
engine.dump_to_object_file(str(obj_path))
1274+
cache_manager.store_kernel(None, asm, options)
1275+
return engine
12191276
elif not options.compile_to_mlir:
12201277
# LLVM flow: only compile to VMFB when not in MLIR-only mode
12211278
compiled_wave_vmfb = compile_to_vmfb(asm, options)

wave_lang/kernel/wave/execution_engine/bindings.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,24 @@ NB_MODULE(wave_execution_engine, m) {
191191
192192
Raises:
193193
RuntimeError: If function lookup fails)")
194+
.def(
195+
"load_from_object_file",
196+
[](wave::ExecutionEngine &self, const std::string &filename) {
197+
return reinterpret_cast<uintptr_t>(
198+
unwrapExpected(self.loadFromObjectFile(filename),
199+
"Failed to load object file"));
200+
},
201+
nb::arg("filename"),
202+
R"(Load a pre-compiled object file into the execution engine.
203+
204+
Args:
205+
filename: Path to the object file
206+
207+
Returns:
208+
Module handle as integer
209+
210+
Raises:
211+
RuntimeError: If loading fails)")
194212
.def(
195213
"dump_to_object_file",
196214
[](wave::ExecutionEngine &self, const std::string &filename) {

wave_lang/kernel/wave/execution_engine/execution_engine.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,49 @@ wave::ExecutionEngine::lookup(wave::ExecutionEngine::ModuleHandle handle,
368368
return makeStringError("looked up function is null");
369369
}
370370

371+
/// Load a pre-compiled object file from `filename` into a freshly created
/// JITDylib and run its initializers.
///
/// Returns the new JITDylib as a module handle on success. Every failure
/// (unreadable file, dylib creation, symbol definition, adding the object,
/// running initializers) is returned as an error rather than aborting, so a
/// truncated or stale cache entry can be handled by the caller.
llvm::Expected<wave::ExecutionEngine::ModuleHandle>
wave::ExecutionEngine::loadFromObjectFile(llvm::StringRef filename) {
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
      llvm::MemoryBuffer::getFile(filename);
  if (!fileOrErr)
    return makeStringError("could not open object file '" + filename +
                           "': " + fileOrErr.getError().message());

  // Create a unique JITDylib for this object.
  llvm::orc::JITDylib *dylib = nullptr;
  while (true) {
    std::string uniqueName =
        (llvm::Twine("module") + llvm::Twine(uniqueNameCounter++)).str();
    if (jit->getJITDylibByName(uniqueName))
      continue;

    llvm::Expected<llvm::orc::JITDylib &> res =
        jit->createJITDylib(std::move(uniqueName));
    if (!res)
      return res.takeError();

    dylib = &res.get();
    break;
  }
  assert(dylib && "failed to create JITDylib");

  // From here on the dylib is registered with the execution session. On any
  // failure, remove it again so the engine is not left holding a partially
  // loaded module, and report both the original error and any removal error.
  auto failAndRemoveDylib = [dylib](llvm::Error err) -> llvm::Error {
    return llvm::joinErrors(
        std::move(err), dylib->getExecutionSession().removeJITDylib(*dylib));
  };

  const llvm::DataLayout &dataLayout = jit->getDataLayout();
  auto generator =
      llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
          dataLayout.getGlobalPrefix());
  if (!generator)
    return failAndRemoveDylib(generator.takeError());
  dylib->addGenerator(std::move(*generator));

  if (symbolMap) {
    if (llvm::Error err = dylib->define(
            absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
                dylib->getExecutionSession(), dataLayout)))))
      return failAndRemoveDylib(std::move(err));
  }

  // The object file comes from an on-disk cache and may be corrupt or
  // truncated, so these steps must not use cantFail (which would abort the
  // process).
  if (llvm::Error err =
          jit->addObjectFile(*dylib, std::move(fileOrErr.get())))
    return failAndRemoveDylib(std::move(err));
  if (llvm::Error err = jit->initialize(*dylib))
    return failAndRemoveDylib(std::move(err));

  return static_cast<ModuleHandle>(dylib);
}
413+
371414
llvm::Error wave::ExecutionEngine::dumpToObjectFile(llvm::StringRef filename) {
372415
if (cache == nullptr)
373416
return makeStringError("cannot dump ExecutionEngine object code to file: "

wave_lang/kernel/wave/execution_engine/execution_engine.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ class ExecutionEngine {
9696
llvm::Expected<void *> lookup(ModuleHandle handle,
9797
llvm::StringRef name) const;
9898

99+
/// Load a pre-compiled object file into the execution engine.
100+
llvm::Expected<ModuleHandle> loadFromObjectFile(llvm::StringRef filename);
101+
99102
/// Dump object code to output file `filename`.
100103
llvm::Error dumpToObjectFile(llvm::StringRef filename);
101104

wave_lang/kernel/wave/execution_engine/execution_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _create_options_from_env() -> "ExecutionEngineOptions":
158158
def _env_enabled(var: str, default: str = "0") -> bool:
159159
return bool(int(os.environ.get(var, default)))
160160

161-
options.enable_object_cache = _env_enabled("WAVE_ENABLE_OBJECT_CACHE")
161+
options.enable_object_cache = _env_enabled("WAVE_ENABLE_OBJECT_CACHE", "1")
162162
options.enable_gdb_notification_listener = _env_enabled("WAVE_ENABLE_GDB_LISTENER")
163163
options.enable_perf_notification_listener = _env_enabled(
164164
"WAVE_ENABLE_PERF_LISTENER"

0 commit comments

Comments
 (0)