10
10
from typing import Tuple
11
11
12
12
import numba .core .event as ev
13
+ from llvmlite .binding .value import ValueRef
13
14
from numba .core import errors , sigutils , types
14
15
from numba .core .compiler import CompileResult , Flags
15
16
from numba .core .compiler_lock import global_compiler_lock
16
17
from numba .core .dispatcher import Dispatcher , _FunctionCompiler
18
+ from numba .core .funcdesc import PythonFunctionDescriptor
17
19
from numba .core .target_extension import dispatcher_registry , target_registry
18
20
from numba .core .types import void
19
21
from numba .core .typing .typeof import Purpose , typeof
20
22
21
23
from numba_dpex import config , spirv_generator
24
+ from numba_dpex .core .codegen import SPIRVCodeLibrary
22
25
from numba_dpex .core .exceptions import (
23
26
ExecutionQueueInferenceError ,
24
27
KernelHasReturnValueError ,
25
28
UnsupportedKernelArgumentError ,
26
29
)
27
30
from numba_dpex .core .pipelines import kernel_compiler
28
- from numba_dpex .core .targets .kernel_target import CompilationMode
31
+ from numba_dpex .core .targets .kernel_target import (
32
+ CompilationMode ,
33
+ DpexKernelTargetContext ,
34
+ )
29
35
from numba_dpex .core .types import DpnpNdArray
36
+ from numba_dpex .core .utils import kernel_launcher as kl
30
37
31
38
from .target import DPEX_KERNEL_EXP_TARGET_NAME , dpex_exp_kernel_target
32
39
33
- _KernelModule = namedtuple ("_KernelModule" , ["kernel_name" , "kernel_bitcode" ])
34
-
35
40
_KernelCompileResult = namedtuple (
36
41
"_KernelCompileResult" , CompileResult ._fields + ("kernel_device_ir_module" ,)
37
42
)
@@ -76,9 +81,14 @@ def check_queue_equivalence_of_args(
76
81
)
77
82
78
83
def _compile_to_spirv (
79
- self , kernel_library , kernel_fndesc , kernel_targetctx
84
+ self ,
85
+ kernel_library : SPIRVCodeLibrary ,
86
+ kernel_fndesc : PythonFunctionDescriptor ,
87
+ kernel_targetctx : DpexKernelTargetContext ,
80
88
):
81
- kernel_func = kernel_library .get_function (kernel_fndesc .llvm_func_name )
89
+ kernel_func : ValueRef = kernel_library .get_function (
90
+ kernel_fndesc .llvm_func_name
91
+ )
82
92
83
93
# Create a spir_kernel wrapper function
84
94
kernel_fn = kernel_targetctx .prepare_spir_kernel (
@@ -103,11 +113,11 @@ def _compile_to_spirv(
103
113
kernel_library .final_module ,
104
114
kernel_library .final_module .as_bitcode (),
105
115
)
106
- return _KernelModule (
116
+ return kl . SPIRVKernelModule (
107
117
kernel_name = kernel_fn .name , kernel_bitcode = kernel_spirv_module
108
118
)
109
119
110
- def compile (self , args , return_type ):
120
+ def compile (self , args , return_type ) -> _KernelCompileResult :
111
121
status , kcres = self ._compile_cached (args , return_type )
112
122
if status :
113
123
return kcres
@@ -160,8 +170,10 @@ def _compile_cached(
160
170
self .targetoptions ["_compilation_mode" ]
161
171
== CompilationMode .KERNEL
162
172
):
163
- kernel_device_ir_module : _KernelModule = self ._compile_to_spirv (
164
- cres .library , cres .fndesc , cres .target_context
173
+ kernel_device_ir_module : kl .SPIRVKernelModule = (
174
+ self ._compile_to_spirv (
175
+ cres .library , cres .fndesc , cres .target_context
176
+ )
165
177
)
166
178
else :
167
179
kernel_device_ir_module = None
@@ -329,14 +341,17 @@ def cb_llvm(dur):
329
341
# Add code to enable on disk caching of a binary spirv kernel.
330
342
# Refer: https://github.com/IntelPython/numba-dpex/issues/1197
331
343
self ._cache_misses [sig ] += 1
332
- ev_details = {
333
- "dispatcher" : self ,
334
- "args" : args ,
335
- "return_type" : return_type ,
336
- }
337
- with ev .trigger_event ("numba_dpex:compile" , data = ev_details ):
344
+ with ev .trigger_event (
345
+ "numba_dpex:compile" ,
346
+ data = {
347
+ "dispatcher" : self ,
348
+ "args" : args ,
349
+ "return_type" : return_type ,
350
+ },
351
+ ):
338
352
try :
339
- kcres : _KernelCompileResult = self ._compiler .compile (
353
+ compiler : _KernelCompiler = self ._compiler
354
+ kcres : _KernelCompileResult = compiler .compile (
340
355
args , return_type
341
356
)
342
357
except errors .ForceLiteralArg as e :
0 commit comments