Skip to content

Commit 7eb3e9a

Browse files
committed
Add typing for experimental kernel dispatcher
1 parent 3f8cdf9 commit 7eb3e9a

File tree

1 file changed

+31
-16
lines changed

1 file changed

+31
-16
lines changed

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,33 @@
1010
from typing import Tuple
1111

1212
import numba.core.event as ev
13+
from llvmlite.binding.value import ValueRef
1314
from numba.core import errors, sigutils, types
1415
from numba.core.compiler import CompileResult, Flags
1516
from numba.core.compiler_lock import global_compiler_lock
1617
from numba.core.dispatcher import Dispatcher, _FunctionCompiler
18+
from numba.core.funcdesc import PythonFunctionDescriptor
1719
from numba.core.target_extension import dispatcher_registry, target_registry
1820
from numba.core.types import void
1921
from numba.core.typing.typeof import Purpose, typeof
2022

2123
from numba_dpex import config, spirv_generator
24+
from numba_dpex.core.codegen import SPIRVCodeLibrary
2225
from numba_dpex.core.exceptions import (
2326
ExecutionQueueInferenceError,
2427
KernelHasReturnValueError,
2528
UnsupportedKernelArgumentError,
2629
)
2730
from numba_dpex.core.pipelines import kernel_compiler
28-
from numba_dpex.core.targets.kernel_target import CompilationMode
31+
from numba_dpex.core.targets.kernel_target import (
32+
CompilationMode,
33+
DpexKernelTargetContext,
34+
)
2935
from numba_dpex.core.types import DpnpNdArray
36+
from numba_dpex.core.utils import kernel_launcher as kl
3037

3138
from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target
3239

33-
_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])
34-
3540
_KernelCompileResult = namedtuple(
3641
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
3742
)
@@ -76,9 +81,14 @@ def check_queue_equivalence_of_args(
7681
)
7782

7883
def _compile_to_spirv(
79-
self, kernel_library, kernel_fndesc, kernel_targetctx
84+
self,
85+
kernel_library: SPIRVCodeLibrary,
86+
kernel_fndesc: PythonFunctionDescriptor,
87+
kernel_targetctx: DpexKernelTargetContext,
8088
):
81-
kernel_func = kernel_library.get_function(kernel_fndesc.llvm_func_name)
89+
kernel_func: ValueRef = kernel_library.get_function(
90+
kernel_fndesc.llvm_func_name
91+
)
8292

8393
# Create a spir_kernel wrapper function
8494
kernel_fn = kernel_targetctx.prepare_spir_kernel(
@@ -103,11 +113,11 @@ def _compile_to_spirv(
103113
kernel_library.final_module,
104114
kernel_library.final_module.as_bitcode(),
105115
)
106-
return _KernelModule(
116+
return kl.SPIRVKernelModule(
107117
kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
108118
)
109119

110-
def compile(self, args, return_type):
120+
def compile(self, args, return_type) -> _KernelCompileResult:
111121
status, kcres = self._compile_cached(args, return_type)
112122
if status:
113123
return kcres
@@ -160,8 +170,10 @@ def _compile_cached(
160170
self.targetoptions["_compilation_mode"]
161171
== CompilationMode.KERNEL
162172
):
163-
kernel_device_ir_module: _KernelModule = self._compile_to_spirv(
164-
cres.library, cres.fndesc, cres.target_context
173+
kernel_device_ir_module: kl.SPIRVKernelModule = (
174+
self._compile_to_spirv(
175+
cres.library, cres.fndesc, cres.target_context
176+
)
165177
)
166178
else:
167179
kernel_device_ir_module = None
@@ -329,14 +341,17 @@ def cb_llvm(dur):
329341
# Add code to enable on disk caching of a binary spirv kernel.
330342
# Refer: https://github.com/IntelPython/numba-dpex/issues/1197
331343
self._cache_misses[sig] += 1
332-
ev_details = {
333-
"dispatcher": self,
334-
"args": args,
335-
"return_type": return_type,
336-
}
337-
with ev.trigger_event("numba_dpex:compile", data=ev_details):
344+
with ev.trigger_event(
345+
"numba_dpex:compile",
346+
data={
347+
"dispatcher": self,
348+
"args": args,
349+
"return_type": return_type,
350+
},
351+
):
338352
try:
339-
kcres: _KernelCompileResult = self._compiler.compile(
353+
compiler: _KernelCompiler = self._compiler
354+
kcres: _KernelCompileResult = compiler.compile(
340355
args, return_type
341356
)
342357
except errors.ForceLiteralArg as e:

0 commit comments

Comments
 (0)