20
20
from numba .core .types import void
21
21
from numba .core .typing .typeof import Purpose , typeof
22
22
23
- from numba_dpex import config , spirv_generator
23
+ from numba_dpex import config , numba_sem_version , spirv_generator
24
24
from numba_dpex .core .codegen import SPIRVCodeLibrary
25
25
from numba_dpex .core .exceptions import (
26
26
ExecutionQueueInferenceError ,
@@ -220,8 +220,6 @@ class KernelDispatcher(Dispatcher):
220
220
targetdescr = dpex_exp_kernel_target
221
221
_fold_args = False
222
222
223
- Dispatcher ._impl_kinds ["kernel" ] = _KernelCompiler
224
-
225
223
def __init__ (
226
224
self ,
227
225
pyfunc ,
@@ -240,12 +238,27 @@ def __init__(
240
238
241
239
self ._kernel_name = pyfunc .__name__
242
240
243
- super ().__init__ (
244
- py_func = pyfunc ,
245
- locals = local_vars_to_numba_types ,
246
- impl_kind = "kernel" ,
247
- targetoptions = targetoptions ,
248
- pipeline_class = pipeline_class ,
241
+ if numba_sem_version < (0 , 59 , 0 ):
242
+ super ().__init__ (
243
+ py_func = pyfunc ,
244
+ locals = local_vars_to_numba_types ,
245
+ impl_kind = "direct" ,
246
+ targetoptions = targetoptions ,
247
+ pipeline_class = pipeline_class ,
248
+ )
249
+ else :
250
+ super ().__init__ (
251
+ py_func = pyfunc ,
252
+ locals = local_vars_to_numba_types ,
253
+ targetoptions = targetoptions ,
254
+ pipeline_class = pipeline_class ,
255
+ )
256
+ self ._compiler = _KernelCompiler (
257
+ pyfunc ,
258
+ self .targetdescr ,
259
+ targetoptions ,
260
+ local_vars_to_numba_types ,
261
+ pipeline_class ,
249
262
)
250
263
251
264
def typeof_pyval (self , val ):
0 commit comments