Skip to content

Commit df32f71

Browse files
author
Diptorup Deb
authored
Merge pull request #1314 from IntelPython/refactor/kernel_api_impl
A new _kernel_api_impl module
2 parents b494492 + 7d8c135 commit df32f71

File tree

18 files changed

+94
-76
lines changed

18 files changed

+94
-76
lines changed

numba_dpex/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from numba_dpex.core.kernel_interface.launcher import call_kernel
2020

21+
from ._kernel_api_impl.spirv import target as spirv_kernel_target
2122
from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc
2223

2324

@@ -107,7 +108,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
107108
# backward compatibility
108109
from numba_dpex.kernel_api import NdRange, Range # noqa E402
109110

110-
from .core.targets import dpjit_target, kernel_target # noqa E402
111+
from .core.targets import dpjit_target # noqa E402
111112
from .decorators import dpjit, func, kernel # noqa E402
112113
from .ocl.stubs import ( # noqa E402
113114
GLOBAL_MEM_FENCE,
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""The module stores the numba_dpex backends implementing the target-specific
6+
code generation for the kernel_api Python functions.
7+
"""
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""A SPIR-V backend to compile the numba_dpex.kernel_api functions to SPIR-V.
6+
"""

numba_dpex/experimental/kernel_dispatcher.py renamed to numba_dpex/_kernel_api_impl/spirv/dispatcher.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,31 @@
2626
from numba.core.typing.typeof import Purpose, typeof
2727

2828
from numba_dpex import config, numba_sem_version, spirv_generator
29-
from numba_dpex.core.codegen import SPIRVCodeLibrary
29+
from numba_dpex._kernel_api_impl.spirv.codegen import SPIRVCodeLibrary
30+
from numba_dpex._kernel_api_impl.spirv.target import (
31+
CompilationMode,
32+
SPIRVTargetContext,
33+
)
3034
from numba_dpex.core.exceptions import (
3135
ExecutionQueueInferenceError,
3236
InvalidKernelSpecializationError,
3337
KernelHasReturnValueError,
3438
UnsupportedKernelArgumentError,
3539
)
3640
from numba_dpex.core.pipelines import kernel_compiler
37-
from numba_dpex.core.targets.kernel_target import (
38-
CompilationMode,
39-
DpexKernelTargetContext,
40-
)
4141
from numba_dpex.core.types import USMNdArray
4242
from numba_dpex.core.utils import kernel_launcher as kl
43+
from numba_dpex.experimental.target import (
44+
DPEX_KERNEL_EXP_TARGET_NAME,
45+
dpex_exp_kernel_target,
46+
)
4347

44-
from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target
45-
46-
_KernelCompileResult = namedtuple(
48+
_SPIRVKernelCompileResult = namedtuple(
4749
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
4850
)
4951

5052

51-
class _KernelCompiler(_FunctionCompiler):
53+
class _SPIRVKernelCompiler(_FunctionCompiler):
5254
"""A special compiler class used to compile numba_dpex.kernel decorated
5355
functions.
5456
"""
@@ -155,7 +157,7 @@ def _compile_to_spirv(
155157
self,
156158
kernel_library: SPIRVCodeLibrary,
157159
kernel_fndesc: PythonFunctionDescriptor,
158-
kernel_targetctx: DpexKernelTargetContext,
160+
kernel_targetctx: SPIRVTargetContext,
159161
):
160162
kernel_func: ValueRef = kernel_library.get_function(
161163
kernel_fndesc.llvm_func_name
@@ -188,7 +190,7 @@ def _compile_to_spirv(
188190
kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
189191
)
190192

191-
def compile(self, args, return_type) -> _KernelCompileResult:
193+
def compile(self, args, return_type) -> _SPIRVKernelCompileResult:
192194
status, kcres = self._compile_cached(args, return_type)
193195
if status:
194196
return kcres
@@ -197,7 +199,7 @@ def compile(self, args, return_type) -> _KernelCompileResult:
197199

198200
def _compile_cached(
199201
self, args, return_type: types.Type
200-
) -> Tuple[bool, _KernelCompileResult]:
202+
) -> Tuple[bool, _SPIRVKernelCompileResult]:
201203
"""Compiles the kernel function to bitcode and generates a host-callable
202204
wrapper to submit the kernel to a SYCL queue.
203205
@@ -277,10 +279,10 @@ def _compile_cached(
277279
self._failed_cache[key] = err
278280
return False, err
279281

280-
return True, _KernelCompileResult(*kcres_attrs)
282+
return True, _SPIRVKernelCompileResult(*kcres_attrs)
281283

282284

283-
class KernelDispatcher(Dispatcher):
285+
class SPIRVKernelDispatcher(Dispatcher):
284286
"""Dispatcher class designed to compile kernel decorated functions. The
285287
dispatcher inherits the Numba Dispatcher class, but has a different
286288
compilation strategy. Instead of compiling a kernel decorated function to
@@ -325,7 +327,7 @@ def __init__(
325327
targetoptions=targetoptions,
326328
pipeline_class=pipeline_class,
327329
)
328-
self._compiler = _KernelCompiler(
330+
self._compiler = _SPIRVKernelCompiler(
329331
pyfunc,
330332
self.targetdescr,
331333
targetoptions,
@@ -426,8 +428,8 @@ def cb_llvm(dur):
426428
},
427429
):
428430
try:
429-
compiler: _KernelCompiler = self._compiler
430-
kcres: _KernelCompileResult = compiler.compile(
431+
compiler: _SPIRVKernelCompiler = self._compiler
432+
kcres: _SPIRVKernelCompileResult = compiler.compile(
431433
args, return_type
432434
)
433435
except errors.ForceLiteralArg as err:
@@ -463,4 +465,4 @@ def __call__(self, *args, **kw_args):
463465

464466

465467
_dpex_target = target_registry[DPEX_KERNEL_EXP_TARGET_NAME]
466-
dispatcher_registry[_dpex_target] = KernelDispatcher
468+
dispatcher_registry[_dpex_target] = SPIRVKernelDispatcher

numba_dpex/core/targets/kernel_target.py renamed to numba_dpex/_kernel_api_impl/spirv/target.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from numba_dpex.core.utils import get_info_from_suai
2424
from numba_dpex.utils import address_space, calling_conv
2525

26-
from .. import codegen
26+
from . import codegen
2727

2828
CC_SPIR_KERNEL = "spir_kernel"
2929
CC_SPIR_FUNC = "spir_func"
@@ -52,7 +52,7 @@ class CompilationMode(IntEnum):
5252
DEVICE_FUNC = 2
5353

5454

55-
class DpexKernelTypingContext(typing.BaseContext):
55+
class SPIRVTypingContext(typing.BaseContext):
5656
"""Custom typing context to support kernel compilation.
5757
5858
The customized typing context provides two features required to compile
@@ -124,20 +124,20 @@ def load_additional_registries(self):
124124
self.install_registry(enumdecl.registry)
125125

126126

127-
class SyclDevice(GPU):
128-
"""Mark the hardware target as SYCL Device."""
127+
class SPIRVDevice(GPU):
128+
"""Mark the hardware target as device that supports SPIR-V bitcode."""
129129

130130
pass
131131

132132

133-
DPEX_KERNEL_TARGET_NAME = "dpex_kernel"
133+
SPIRV_TARGET_NAME = "spirv"
134134

135-
target_registry[DPEX_KERNEL_TARGET_NAME] = SyclDevice
135+
target_registry[SPIRV_TARGET_NAME] = SPIRVDevice
136136

137137

138-
class DpexKernelTargetContext(BaseContext):
138+
class SPIRVTargetContext(BaseContext):
139139
"""A target context inheriting Numba's ``BaseContext`` that is customized
140-
for generating SYCL kernels.
140+
for generating SPIR-V kernels.
141141
142142
A customized target context for generating SPIR-V kernels. The class defines
143143
helper functions to generates SPIR-V kernels as LLVM IR using the required
@@ -243,7 +243,7 @@ def _generate_spir_kernel_wrapper(self, func, argtypes):
243243
module.get_function(func.name).linkage = "internal"
244244
return wrapper
245245

246-
def __init__(self, typingctx, target=DPEX_KERNEL_TARGET_NAME):
246+
def __init__(self, typingctx, target=SPIRV_TARGET_NAME):
247247
super().__init__(typingctx, target)
248248

249249
def init(self):
@@ -338,7 +338,7 @@ def load_additional_registries(self):
338338

339339
@cached_property
340340
def call_conv(self):
341-
return DpexCallConv(self)
341+
return SPIRVCallConv(self)
342342

343343
def codegen(self):
344344
return self._internal_codegen
@@ -385,9 +385,7 @@ def declare_function(self, module, fndesc):
385385
)
386386
if not self.enable_debuginfo:
387387
fn.attributes.add("alwaysinline")
388-
ret = super(DpexKernelTargetContext, self).declare_function(
389-
module, fndesc
390-
)
388+
ret = super(SPIRVTargetContext, self).declare_function(module, fndesc)
391389
ret.calling_convention = calling_conv.CC_SPIR_FUNC
392390
return ret
393391

@@ -444,12 +442,12 @@ def populate_array(self, arr, **kwargs):
444442
return arrayobj.populate_array(arr, **kwargs)
445443

446444

447-
class DpexCallConv(MinimalCallConv):
445+
class SPIRVCallConv(MinimalCallConv):
448446
"""Custom calling convention class used by numba-dpex.
449447
450448
numba_dpex's calling convention derives from
451449
:class:`numba.core.callconv import MinimalCallConv`. The
452-
:class:`DpexCallConv` overrides :func:`call_function`.
450+
:class:`SPIRVCallConv` overrides :func:`call_function`.
453451
454452
"""
455453

numba_dpex/core/descriptor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@
88
from numba.core.cpu import CPUTargetOptions
99
from numba.core.descriptors import TargetDescriptor
1010

11+
from numba_dpex._kernel_api_impl.spirv.target import (
12+
SPIRV_TARGET_NAME,
13+
CompilationMode,
14+
SPIRVTargetContext,
15+
SPIRVTypingContext,
16+
)
1117
from numba_dpex.core import config
1218

1319
from .targets.dpjit_target import (
1420
DPEX_TARGET_NAME,
1521
DpexTargetContext,
1622
DpexTypingContext,
1723
)
18-
from .targets.kernel_target import (
19-
DPEX_KERNEL_TARGET_NAME,
20-
CompilationMode,
21-
DpexKernelTargetContext,
22-
DpexKernelTypingContext,
23-
)
2424

2525
_option_mapping = options._mapping
2626

@@ -77,12 +77,12 @@ class DpexKernelTarget(TargetDescriptor):
7777
@cached_property
7878
def _toplevel_target_context(self):
7979
"""Lazily-initialized top-level target context, for all threads."""
80-
return DpexKernelTargetContext(self.typing_context, self._target_name)
80+
return SPIRVTargetContext(self.typing_context, self._target_name)
8181

8282
@cached_property
8383
def _toplevel_typing_context(self):
8484
"""Lazily-initialized top-level typing context, for all threads."""
85-
return DpexKernelTypingContext()
85+
return SPIRVTypingContext()
8686

8787
@property
8888
def target_context(self):
@@ -132,7 +132,7 @@ def typing_context(self):
132132

133133

134134
# A global instance of the DpexKernelTarget
135-
dpex_kernel_target = DpexKernelTarget(DPEX_KERNEL_TARGET_NAME)
135+
dpex_kernel_target = DpexKernelTarget(SPIRV_TARGET_NAME)
136136

137137
# A global instance of the DpexTarget
138138
dpex_target = DpexTarget(DPEX_TARGET_NAME)

numba_dpex/core/kernel_interface/spirv_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from numba.core import ir
99

1010
from numba_dpex import spirv_generator
11+
from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
1112
from numba_dpex.core import config
1213
from numba_dpex.core.compiler import compile_with_dpex
1314
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError
14-
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
1515

1616
from .kernel_base import KernelInterface
1717

@@ -135,7 +135,7 @@ def compile(
135135
)
136136

137137
func = cres.library.get_function(cres.fndesc.llvm_func_name)
138-
kernel_targetctx: DpexKernelTargetContext = cres.target_context
138+
kernel_targetctx: SPIRVTargetContext = cres.target_context
139139
kernel = kernel_targetctx.prepare_spir_kernel(func, cres.signature.args)
140140

141141
# XXX: Setting the inline_threshold in the following way is a temporary

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from numba.np.arrayobj import make_array
1717
from numba.np.numpy_support import is_nonelike
1818

19+
from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
1920
from numba_dpex.core.kernel_interface.arrayobj import (
2021
_getitem_array_generic as kernel_getitem_array_generic,
2122
)
22-
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
2323
from numba_dpex.core.types import DpnpNdArray
2424

2525
from ._intrinsic import (
@@ -1082,7 +1082,7 @@ def getitem_arraynd_intp(context, builder, sig, args):
10821082
that when returning a view of a dpnp.ndarray the sycl::queue pointer
10831083
member in the LLVM IR struct gets properly updated.
10841084
"""
1085-
getitem_call_in_kernel = isinstance(context, DpexKernelTargetContext)
1085+
getitem_call_in_kernel = isinstance(context, SPIRVTargetContext)
10861086
_getitem_array_generic = np_getitem_array_generic
10871087

10881088
if getitem_call_in_kernel:

numba_dpex/experimental/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from numba.core.imputils import Registry
1010

11+
from numba_dpex._kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
12+
1113
# Temporary so that Range and NdRange work in experimental call_kernel
1214
from numba_dpex.core.boxing import *
1315

@@ -17,7 +19,6 @@
1719
_index_space_id_overloads,
1820
)
1921
from .decorators import device_func, kernel
20-
from .kernel_dispatcher import KernelDispatcher
2122
from .launcher import call_kernel, call_kernel_async
2223
from .literal_intenum_type import IntEnumLiteral
2324
from .models import *
@@ -41,5 +42,5 @@ def dpex_dispatcher_const(context):
4142
"call_kernel",
4243
"call_kernel_async",
4344
"IntEnumLiteral",
44-
"KernelDispatcher",
45+
"SPIRVKernelDispatcher",
4546
]

0 commit comments

Comments
 (0)