Skip to content

Commit 2ab2608

Browse files
author
Diptorup Deb
committed
Enable overload caching in kernel_dispatcher.
- Use function qualname as the entry point (key) for caching compiled function in the target's _defns dict. - KernelCompileResult now extends numba's CompileResult class and only stores the spirv bitcode as an extra field. - Added a new function to return just the device ir for a cached overload. - Updates to launcher to accomodate changes to kernel_dispatcher.
1 parent 185a61b commit 2ab2608

File tree

2 files changed

+48
-24
lines changed

2 files changed

+48
-24
lines changed

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])
3030

3131
_KernelCompileResult = namedtuple(
32-
"_KernelCompileResult",
33-
["status", "cres_or_error", "entry_point"],
32+
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
3433
)
3534

3635

@@ -96,11 +95,11 @@ def _compile_to_spirv(
9695
)
9796

9897
def compile(self, args, return_type):
99-
kcres = self._compile_cached(args, return_type)
100-
if kcres.status:
98+
status, kcres = self._compile_cached(args, return_type)
99+
if status:
101100
return kcres
102101

103-
raise kcres.cres_or_error
102+
raise kcres
104103

105104
def _compile_cached(
106105
self, args, return_type: types.Type
@@ -137,34 +136,45 @@ def _compile_cached(
137136
"""
138137
key = tuple(args), return_type
139138
try:
140-
return _KernelCompileResult(False, self._failed_cache[key], None)
139+
return False, self._failed_cache[key]
141140
except KeyError:
142141
pass
143142

144143
try:
145-
kernel_cres: CompileResult = self._compile_core(args, return_type)
144+
cres: CompileResult = self._compile_core(args, return_type)
146145

147-
kernel_library = kernel_cres.library
148-
kernel_fndesc = kernel_cres.fndesc
149-
kernel_targetctx = kernel_cres.target_context
150-
151-
kernel_module = self._compile_to_spirv(
152-
kernel_library, kernel_fndesc, kernel_targetctx
146+
kernel_device_ir_module = self._compile_to_spirv(
147+
cres.library, cres.fndesc, cres.target_context
153148
)
154149

150+
kcres_attrs = []
151+
152+
for cres_field in cres._fields:
153+
cres_attr = getattr(cres, cres_field)
154+
if cres_field == "entry_point":
155+
if cres_attr is not None:
156+
raise AssertionError(
157+
"Compiled kernel and device_func should be "
158+
"compiled with compile_cfunc option turned off"
159+
)
160+
cres_attr = cres.fndesc.qualname
161+
kcres_attrs.append(cres_attr)
162+
163+
kcres_attrs.append(kernel_device_ir_module)
164+
155165
if config.DUMP_KERNEL_LLVM:
156166
with open(
157-
kernel_cres.fndesc.llvm_func_name + ".ll",
167+
cres.fndesc.llvm_func_name + ".ll",
158168
"w",
159169
encoding="UTF-8",
160170
) as f:
161-
f.write(kernel_cres.library.final_module)
171+
f.write(cres.library.final_module)
162172

163173
except errors.TypingError as e:
164174
self._failed_cache[key] = e
165-
return _KernelCompileResult(False, e, None)
175+
return False, e
166176

167-
return _KernelCompileResult(True, kernel_cres, kernel_module)
177+
return True, _KernelCompileResult(*kcres_attrs)
168178

169179

170180
class KernelDispatcher(Dispatcher):
@@ -234,7 +244,14 @@ def typeof_pyval(self, val):
234244

235245
def add_overload(self, cres):
236246
args = tuple(cres.signature.args)
237-
self.overloads[args] = cres.entry_point
247+
self.overloads[args] = cres
248+
249+
def get_overload_device_ir(self, sig):
250+
"""
251+
Return the compiled device bitcode for the given signature.
252+
"""
253+
args, _ = sigutils.normalize_signature(sig)
254+
return self.overloads[tuple(args)].kernel_device_ir_module
238255

239256
def compile(self, sig) -> _KernelCompileResult:
240257
disp = self._get_dispatcher_for_current_target()
@@ -274,7 +291,7 @@ def cb_llvm(dur):
274291
# Don't recompile if signature already exists
275292
existing = self.overloads.get(tuple(args))
276293
if existing is not None:
277-
return existing
294+
return existing.entry_point
278295

279296
# TODO: Enable caching
280297
# Add code to enable on disk caching of a binary spirv kernel.
@@ -298,7 +315,11 @@ def folded(args, kws):
298315
)[1]
299316

300317
raise e.bind_fold_arguments(folded)
301-
self.add_overload(kcres.cres_or_error)
318+
self.add_overload(kcres)
319+
320+
kcres.target_context.insert_user_function(
321+
kcres.entry_point, kcres.fndesc, [kcres.library]
322+
)
302323

303324
# TODO: enable caching of kernel_module
304325
# https://github.com/IntelPython/numba-dpex/issues/1197

numba_dpex/experimental/launcher.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,10 @@ def intrin_launch_trampoline(
313313
sig = types.void(kernel_fn, index_space, kernel_args)
314314
# signature of the kernel_fn
315315
kernel_sig = types.void(*kernel_args_list)
316-
kmodule: _KernelModule = kernel_fn.dispatcher.compile(kernel_sig)
316+
kernel_fn.dispatcher.compile(kernel_sig)
317+
kernel_module: _KernelModule = kernel_fn.dispatcher.get_overload_device_ir(
318+
kernel_sig
319+
)
317320
kernel_targetctx = kernel_fn.dispatcher.targetctx
318321

319322
def codegen(cgctx, builder, sig, llargs):
@@ -329,7 +332,7 @@ def codegen(cgctx, builder, sig, llargs):
329332
)
330333

331334
kernel_bc_byte_str = fn_body_gen.insert_kernel_bitcode_as_byte_str(
332-
kmodule
335+
kernel_module
333336
)
334337

335338
populated_kernel_args = (
@@ -346,10 +349,10 @@ def codegen(cgctx, builder, sig, llargs):
346349
kbref = fn_body_gen.create_kernel_bundle_from_spirv(
347350
queue_ref=qref,
348351
kernel_bc=kernel_bc_byte_str,
349-
kernel_bc_size_in_bytes=len(kmodule.kernel_bitcode),
352+
kernel_bc_size_in_bytes=len(kernel_module.kernel_bitcode),
350353
)
351354

352-
kref = fn_body_gen.get_kernel(kmodule, kbref)
355+
kref = fn_body_gen.get_kernel(kernel_module, kbref)
353356

354357
index_space_values = fn_body_gen.create_llvm_values_for_index_space(
355358
indexer_argty=sig.args[1],

0 commit comments

Comments
 (0)