Skip to content

Commit 72d8a32

Browse files
authored
Merge pull request #1250 from IntelPython/support_numba_0.59
Changes to KernelDispatcher to support numba 0.59
2 parents ea5b3bd + a20cb0a commit 72d8a32

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from numba.core.types import void
2121
from numba.core.typing.typeof import Purpose, typeof
2222

23-
from numba_dpex import config, spirv_generator
23+
from numba_dpex import config, numba_sem_version, spirv_generator
2424
from numba_dpex.core.codegen import SPIRVCodeLibrary
2525
from numba_dpex.core.exceptions import (
2626
ExecutionQueueInferenceError,
@@ -220,8 +220,6 @@ class KernelDispatcher(Dispatcher):
220220
targetdescr = dpex_exp_kernel_target
221221
_fold_args = False
222222

223-
Dispatcher._impl_kinds["kernel"] = _KernelCompiler
224-
225223
def __init__(
226224
self,
227225
pyfunc,
@@ -240,12 +238,27 @@ def __init__(
240238

241239
self._kernel_name = pyfunc.__name__
242240

243-
super().__init__(
244-
py_func=pyfunc,
245-
locals=local_vars_to_numba_types,
246-
impl_kind="kernel",
247-
targetoptions=targetoptions,
248-
pipeline_class=pipeline_class,
241+
if numba_sem_version < (0, 59, 0):
242+
super().__init__(
243+
py_func=pyfunc,
244+
locals=local_vars_to_numba_types,
245+
impl_kind="direct",
246+
targetoptions=targetoptions,
247+
pipeline_class=pipeline_class,
248+
)
249+
else:
250+
super().__init__(
251+
py_func=pyfunc,
252+
locals=local_vars_to_numba_types,
253+
targetoptions=targetoptions,
254+
pipeline_class=pipeline_class,
255+
)
256+
self._compiler = _KernelCompiler(
257+
pyfunc,
258+
self.targetdescr,
259+
targetoptions,
260+
local_vars_to_numba_types,
261+
pipeline_class,
249262
)
250263

251264
def typeof_pyval(self, val):

0 commit comments

Comments
 (0)