Skip to content

Commit 5635162

Browse files
author
Diptorup Deb
committed
Migrate the kernel_dispatcher into _kernel_api_impl.spirv
1 parent b494492 commit 5635162

File tree

6 files changed

+38
-22
lines changed

6 files changed

+38
-22
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""The module stores the numba_dpex backends implementing the target-specific
6+
code generation for the kernel_api Python functions.
7+
"""
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""A SPIR-V backend to compile the numba_dpex.kernel_api functions to SPIR-V.
6+
"""

numba_dpex/experimental/kernel_dispatcher.py renamed to numba_dpex/_kernel_api_impl/spirv/dispatcher.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,17 @@
4040
)
4141
from numba_dpex.core.types import USMNdArray
4242
from numba_dpex.core.utils import kernel_launcher as kl
43+
from numba_dpex.experimental.target import (
44+
DPEX_KERNEL_EXP_TARGET_NAME,
45+
dpex_exp_kernel_target,
46+
)
4347

44-
from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target
45-
46-
_KernelCompileResult = namedtuple(
48+
_SPVKernelCompileResult = namedtuple(
4749
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
4850
)
4951

5052

51-
class _KernelCompiler(_FunctionCompiler):
53+
class _SPVKernelCompiler(_FunctionCompiler):
5254
"""A special compiler class used to compile numba_dpex.kernel decorated
5355
functions.
5456
"""
@@ -188,7 +190,7 @@ def _compile_to_spirv(
188190
kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
189191
)
190192

191-
def compile(self, args, return_type) -> _KernelCompileResult:
193+
def compile(self, args, return_type) -> _SPVKernelCompileResult:
192194
status, kcres = self._compile_cached(args, return_type)
193195
if status:
194196
return kcres
@@ -197,7 +199,7 @@ def compile(self, args, return_type) -> _KernelCompileResult:
197199

198200
def _compile_cached(
199201
self, args, return_type: types.Type
200-
) -> Tuple[bool, _KernelCompileResult]:
202+
) -> Tuple[bool, _SPVKernelCompileResult]:
201203
"""Compiles the kernel function to bitcode and generates a host-callable
202204
wrapper to submit the kernel to a SYCL queue.
203205
@@ -277,10 +279,10 @@ def _compile_cached(
277279
self._failed_cache[key] = err
278280
return False, err
279281

280-
return True, _KernelCompileResult(*kcres_attrs)
282+
return True, _SPVKernelCompileResult(*kcres_attrs)
281283

282284

283-
class KernelDispatcher(Dispatcher):
285+
class SPVKernelDispatcher(Dispatcher):
284286
"""Dispatcher class designed to compile kernel decorated functions. The
285287
dispatcher inherits the Numba Dispatcher class, but has a different
286288
compilation strategy. Instead of compiling a kernel decorated function to
@@ -325,7 +327,7 @@ def __init__(
325327
targetoptions=targetoptions,
326328
pipeline_class=pipeline_class,
327329
)
328-
self._compiler = _KernelCompiler(
330+
self._compiler = _SPVKernelCompiler(
329331
pyfunc,
330332
self.targetdescr,
331333
targetoptions,
@@ -426,8 +428,8 @@ def cb_llvm(dur):
426428
},
427429
):
428430
try:
429-
compiler: _KernelCompiler = self._compiler
430-
kcres: _KernelCompileResult = compiler.compile(
431+
compiler: _SPVKernelCompiler = self._compiler
432+
kcres: _SPVKernelCompileResult = compiler.compile(
431433
args, return_type
432434
)
433435
except errors.ForceLiteralArg as err:
@@ -463,4 +465,4 @@ def __call__(self, *args, **kw_args):
463465

464466

465467
_dpex_target = target_registry[DPEX_KERNEL_EXP_TARGET_NAME]
466-
dispatcher_registry[_dpex_target] = KernelDispatcher
468+
dispatcher_registry[_dpex_target] = SPVKernelDispatcher

numba_dpex/experimental/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from numba.core.imputils import Registry
1010

11+
from numba_dpex._kernel_api_impl.spirv.dispatcher import SPVKernelDispatcher
12+
1113
# Temporary so that Range and NdRange work in experimental call_kernel
1214
from numba_dpex.core.boxing import *
1315

@@ -17,7 +19,6 @@
1719
_index_space_id_overloads,
1820
)
1921
from .decorators import device_func, kernel
20-
from .kernel_dispatcher import KernelDispatcher
2122
from .launcher import call_kernel, call_kernel_async
2223
from .literal_intenum_type import IntEnumLiteral
2324
from .models import *
@@ -41,5 +42,5 @@ def dpex_dispatcher_const(context):
4142
"call_kernel",
4243
"call_kernel_async",
4344
"IntEnumLiteral",
44-
"KernelDispatcher",
45+
"SPVKernelDispatcher",
4546
]

numba_dpex/experimental/decorators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
target_registry,
1616
)
1717

18+
from numba_dpex._kernel_api_impl.spirv.dispatcher import SPVKernelDispatcher
1819
from numba_dpex.core.targets.kernel_target import CompilationMode
19-
from numba_dpex.experimental.kernel_dispatcher import KernelDispatcher
2020

2121
from .target import DPEX_KERNEL_EXP_TARGET_NAME
2222

@@ -78,7 +78,7 @@ def kernel(func_or_sig=None, **options):
7878
)
7979

8080
def _kernel_dispatcher(pyfunc):
81-
disp: KernelDispatcher = dispatcher(
81+
disp: SPVKernelDispatcher = dispatcher(
8282
pyfunc=pyfunc,
8383
targetoptions=options,
8484
)

numba_dpex/experimental/launcher.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
from numba.extending import intrinsic
2020

2121
from numba_dpex import dpjit
22+
from numba_dpex._kernel_api_impl.spirv.dispatcher import (
23+
SPVKernelDispatcher,
24+
_SPVKernelCompileResult,
25+
)
2226
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME
2327
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
2428
from numba_dpex.core.types import DpctlSyclEvent, NdRangeType, RangeType
@@ -29,10 +33,6 @@
2933
ItemType,
3034
NdItemType,
3135
)
32-
from numba_dpex.experimental.kernel_dispatcher import (
33-
KernelDispatcher,
34-
_KernelCompileResult,
35-
)
3636

3737

3838
class LLRange(NamedTuple):
@@ -156,8 +156,8 @@ def _submit_kernel( # pylint: disable=too-many-arguments
156156
# ty_kernel_fn is type specific to exact function, so we can get function
157157
# directly from type and compile it. Thats why we don't need to get it in
158158
# codegen
159-
kernel_dispatcher: KernelDispatcher = ty_kernel_fn.dispatcher
160-
kcres: _KernelCompileResult = kernel_dispatcher.get_compile_result(
159+
kernel_dispatcher: SPVKernelDispatcher = ty_kernel_fn.dispatcher
160+
kcres: _SPVKernelCompileResult = kernel_dispatcher.get_compile_result(
161161
types.void(*ty_kernel_args_tuple) # kernel signature
162162
)
163163
kernel_module: kl.SPIRVKernelModule = kcres.kernel_device_ir_module

0 commit comments

Comments
 (0)