Skip to content

Commit 7d8c135

Browse files
author
Diptorup Deb
committed
Move kernel target to _kernel_api_impl.spirv
- The kernel_target was moved into _kernel_api_impl.spirv. - Renaming of classes to properly indicate that the target is for SPIRV code generation.
1 parent 5635162 commit 7d8c135

File tree

16 files changed

+74
-72
lines changed

16 files changed

+74
-72
lines changed

numba_dpex/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from numba_dpex.core.kernel_interface.launcher import call_kernel
2020

21+
from ._kernel_api_impl.spirv import target as spirv_kernel_target
2122
from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc
2223

2324

@@ -107,7 +108,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
107108
# backward compatibility
108109
from numba_dpex.kernel_api import NdRange, Range # noqa E402
109110

110-
from .core.targets import dpjit_target, kernel_target # noqa E402
111+
from .core.targets import dpjit_target # noqa E402
111112
from .decorators import dpjit, func, kernel # noqa E402
112113
from .ocl.stubs import ( # noqa E402
113114
GLOBAL_MEM_FENCE,

numba_dpex/_kernel_api_impl/spirv/dispatcher.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,31 +26,31 @@
2626
from numba.core.typing.typeof import Purpose, typeof
2727

2828
from numba_dpex import config, numba_sem_version, spirv_generator
29-
from numba_dpex.core.codegen import SPIRVCodeLibrary
29+
from numba_dpex._kernel_api_impl.spirv.codegen import SPIRVCodeLibrary
30+
from numba_dpex._kernel_api_impl.spirv.target import (
31+
CompilationMode,
32+
SPIRVTargetContext,
33+
)
3034
from numba_dpex.core.exceptions import (
3135
ExecutionQueueInferenceError,
3236
InvalidKernelSpecializationError,
3337
KernelHasReturnValueError,
3438
UnsupportedKernelArgumentError,
3539
)
3640
from numba_dpex.core.pipelines import kernel_compiler
37-
from numba_dpex.core.targets.kernel_target import (
38-
CompilationMode,
39-
DpexKernelTargetContext,
40-
)
4141
from numba_dpex.core.types import USMNdArray
4242
from numba_dpex.core.utils import kernel_launcher as kl
4343
from numba_dpex.experimental.target import (
4444
DPEX_KERNEL_EXP_TARGET_NAME,
4545
dpex_exp_kernel_target,
4646
)
4747

48-
_SPVKernelCompileResult = namedtuple(
48+
_SPIRVKernelCompileResult = namedtuple(
4949
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
5050
)
5151

5252

53-
class _SPVKernelCompiler(_FunctionCompiler):
53+
class _SPIRVKernelCompiler(_FunctionCompiler):
5454
"""A special compiler class used to compile numba_dpex.kernel decorated
5555
functions.
5656
"""
@@ -157,7 +157,7 @@ def _compile_to_spirv(
157157
self,
158158
kernel_library: SPIRVCodeLibrary,
159159
kernel_fndesc: PythonFunctionDescriptor,
160-
kernel_targetctx: DpexKernelTargetContext,
160+
kernel_targetctx: SPIRVTargetContext,
161161
):
162162
kernel_func: ValueRef = kernel_library.get_function(
163163
kernel_fndesc.llvm_func_name
@@ -190,7 +190,7 @@ def _compile_to_spirv(
190190
kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
191191
)
192192

193-
def compile(self, args, return_type) -> _SPVKernelCompileResult:
193+
def compile(self, args, return_type) -> _SPIRVKernelCompileResult:
194194
status, kcres = self._compile_cached(args, return_type)
195195
if status:
196196
return kcres
@@ -199,7 +199,7 @@ def compile(self, args, return_type) -> _SPVKernelCompileResult:
199199

200200
def _compile_cached(
201201
self, args, return_type: types.Type
202-
) -> Tuple[bool, _SPVKernelCompileResult]:
202+
) -> Tuple[bool, _SPIRVKernelCompileResult]:
203203
"""Compiles the kernel function to bitcode and generates a host-callable
204204
wrapper to submit the kernel to a SYCL queue.
205205
@@ -279,10 +279,10 @@ def _compile_cached(
279279
self._failed_cache[key] = err
280280
return False, err
281281

282-
return True, _SPVKernelCompileResult(*kcres_attrs)
282+
return True, _SPIRVKernelCompileResult(*kcres_attrs)
283283

284284

285-
class SPVKernelDispatcher(Dispatcher):
285+
class SPIRVKernelDispatcher(Dispatcher):
286286
"""Dispatcher class designed to compile kernel decorated functions. The
287287
dispatcher inherits the Numba Dispatcher class, but has a different
288288
compilation strategy. Instead of compiling a kernel decorated function to
@@ -327,7 +327,7 @@ def __init__(
327327
targetoptions=targetoptions,
328328
pipeline_class=pipeline_class,
329329
)
330-
self._compiler = _SPVKernelCompiler(
330+
self._compiler = _SPIRVKernelCompiler(
331331
pyfunc,
332332
self.targetdescr,
333333
targetoptions,
@@ -428,8 +428,8 @@ def cb_llvm(dur):
428428
},
429429
):
430430
try:
431-
compiler: _SPVKernelCompiler = self._compiler
432-
kcres: _SPVKernelCompileResult = compiler.compile(
431+
compiler: _SPIRVKernelCompiler = self._compiler
432+
kcres: _SPIRVKernelCompileResult = compiler.compile(
433433
args, return_type
434434
)
435435
except errors.ForceLiteralArg as err:
@@ -465,4 +465,4 @@ def __call__(self, *args, **kw_args):
465465

466466

467467
_dpex_target = target_registry[DPEX_KERNEL_EXP_TARGET_NAME]
468-
dispatcher_registry[_dpex_target] = SPVKernelDispatcher
468+
dispatcher_registry[_dpex_target] = SPIRVKernelDispatcher

numba_dpex/core/targets/kernel_target.py renamed to numba_dpex/_kernel_api_impl/spirv/target.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from numba_dpex.core.utils import get_info_from_suai
2424
from numba_dpex.utils import address_space, calling_conv
2525

26-
from .. import codegen
26+
from . import codegen
2727

2828
CC_SPIR_KERNEL = "spir_kernel"
2929
CC_SPIR_FUNC = "spir_func"
@@ -52,7 +52,7 @@ class CompilationMode(IntEnum):
5252
DEVICE_FUNC = 2
5353

5454

55-
class DpexKernelTypingContext(typing.BaseContext):
55+
class SPIRVTypingContext(typing.BaseContext):
5656
"""Custom typing context to support kernel compilation.
5757
5858
The customized typing context provides two features required to compile
@@ -124,20 +124,20 @@ def load_additional_registries(self):
124124
self.install_registry(enumdecl.registry)
125125

126126

127-
class SyclDevice(GPU):
128-
"""Mark the hardware target as SYCL Device."""
127+
class SPIRVDevice(GPU):
128+
"""Mark the hardware target as device that supports SPIR-V bitcode."""
129129

130130
pass
131131

132132

133-
DPEX_KERNEL_TARGET_NAME = "dpex_kernel"
133+
SPIRV_TARGET_NAME = "spirv"
134134

135-
target_registry[DPEX_KERNEL_TARGET_NAME] = SyclDevice
135+
target_registry[SPIRV_TARGET_NAME] = SPIRVDevice
136136

137137

138-
class DpexKernelTargetContext(BaseContext):
138+
class SPIRVTargetContext(BaseContext):
139139
"""A target context inheriting Numba's ``BaseContext`` that is customized
140-
for generating SYCL kernels.
140+
for generating SPIR-V kernels.
141141
142142
A customized target context for generating SPIR-V kernels. The class defines
143143
helper functions to generates SPIR-V kernels as LLVM IR using the required
@@ -243,7 +243,7 @@ def _generate_spir_kernel_wrapper(self, func, argtypes):
243243
module.get_function(func.name).linkage = "internal"
244244
return wrapper
245245

246-
def __init__(self, typingctx, target=DPEX_KERNEL_TARGET_NAME):
246+
def __init__(self, typingctx, target=SPIRV_TARGET_NAME):
247247
super().__init__(typingctx, target)
248248

249249
def init(self):
@@ -338,7 +338,7 @@ def load_additional_registries(self):
338338

339339
@cached_property
340340
def call_conv(self):
341-
return DpexCallConv(self)
341+
return SPIRVCallConv(self)
342342

343343
def codegen(self):
344344
return self._internal_codegen
@@ -385,9 +385,7 @@ def declare_function(self, module, fndesc):
385385
)
386386
if not self.enable_debuginfo:
387387
fn.attributes.add("alwaysinline")
388-
ret = super(DpexKernelTargetContext, self).declare_function(
389-
module, fndesc
390-
)
388+
ret = super(SPIRVTargetContext, self).declare_function(module, fndesc)
391389
ret.calling_convention = calling_conv.CC_SPIR_FUNC
392390
return ret
393391

@@ -444,12 +442,12 @@ def populate_array(self, arr, **kwargs):
444442
return arrayobj.populate_array(arr, **kwargs)
445443

446444

447-
class DpexCallConv(MinimalCallConv):
445+
class SPIRVCallConv(MinimalCallConv):
448446
"""Custom calling convention class used by numba-dpex.
449447
450448
numba_dpex's calling convention derives from
451449
:class:`numba.core.callconv import MinimalCallConv`. The
452-
:class:`DpexCallConv` overrides :func:`call_function`.
450+
:class:`SPIRVCallConv` overrides :func:`call_function`.
453451
454452
"""
455453

numba_dpex/core/descriptor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@
88
from numba.core.cpu import CPUTargetOptions
99
from numba.core.descriptors import TargetDescriptor
1010

11+
from numba_dpex._kernel_api_impl.spirv.target import (
12+
SPIRV_TARGET_NAME,
13+
CompilationMode,
14+
SPIRVTargetContext,
15+
SPIRVTypingContext,
16+
)
1117
from numba_dpex.core import config
1218

1319
from .targets.dpjit_target import (
1420
DPEX_TARGET_NAME,
1521
DpexTargetContext,
1622
DpexTypingContext,
1723
)
18-
from .targets.kernel_target import (
19-
DPEX_KERNEL_TARGET_NAME,
20-
CompilationMode,
21-
DpexKernelTargetContext,
22-
DpexKernelTypingContext,
23-
)
2424

2525
_option_mapping = options._mapping
2626

@@ -77,12 +77,12 @@ class DpexKernelTarget(TargetDescriptor):
7777
@cached_property
7878
def _toplevel_target_context(self):
7979
"""Lazily-initialized top-level target context, for all threads."""
80-
return DpexKernelTargetContext(self.typing_context, self._target_name)
80+
return SPIRVTargetContext(self.typing_context, self._target_name)
8181

8282
@cached_property
8383
def _toplevel_typing_context(self):
8484
"""Lazily-initialized top-level typing context, for all threads."""
85-
return DpexKernelTypingContext()
85+
return SPIRVTypingContext()
8686

8787
@property
8888
def target_context(self):
@@ -132,7 +132,7 @@ def typing_context(self):
132132

133133

134134
# A global instance of the DpexKernelTarget
135-
dpex_kernel_target = DpexKernelTarget(DPEX_KERNEL_TARGET_NAME)
135+
dpex_kernel_target = DpexKernelTarget(SPIRV_TARGET_NAME)
136136

137137
# A global instance of the DpexTarget
138138
dpex_target = DpexTarget(DPEX_TARGET_NAME)

numba_dpex/core/kernel_interface/spirv_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from numba.core import ir
99

1010
from numba_dpex import spirv_generator
11+
from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
1112
from numba_dpex.core import config
1213
from numba_dpex.core.compiler import compile_with_dpex
1314
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError
14-
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
1515

1616
from .kernel_base import KernelInterface
1717

@@ -135,7 +135,7 @@ def compile(
135135
)
136136

137137
func = cres.library.get_function(cres.fndesc.llvm_func_name)
138-
kernel_targetctx: DpexKernelTargetContext = cres.target_context
138+
kernel_targetctx: SPIRVTargetContext = cres.target_context
139139
kernel = kernel_targetctx.prepare_spir_kernel(func, cres.signature.args)
140140

141141
# XXX: Setting the inline_threshold in the following way is a temporary

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from numba.np.arrayobj import make_array
1717
from numba.np.numpy_support import is_nonelike
1818

19+
from numba_dpex._kernel_api_impl.spirv.target import SPIRVTargetContext
1920
from numba_dpex.core.kernel_interface.arrayobj import (
2021
_getitem_array_generic as kernel_getitem_array_generic,
2122
)
22-
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
2323
from numba_dpex.core.types import DpnpNdArray
2424

2525
from ._intrinsic import (
@@ -1082,7 +1082,7 @@ def getitem_arraynd_intp(context, builder, sig, args):
10821082
that when returning a view of a dpnp.ndarray the sycl::queue pointer
10831083
member in the LLVM IR struct gets properly updated.
10841084
"""
1085-
getitem_call_in_kernel = isinstance(context, DpexKernelTargetContext)
1085+
getitem_call_in_kernel = isinstance(context, SPIRVTargetContext)
10861086
_getitem_array_generic = np_getitem_array_generic
10871087

10881088
if getitem_call_in_kernel:

numba_dpex/experimental/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from numba.core.imputils import Registry
1010

11-
from numba_dpex._kernel_api_impl.spirv.dispatcher import SPVKernelDispatcher
11+
from numba_dpex._kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
1212

1313
# Temporary so that Range and NdRange work in experimental call_kernel
1414
from numba_dpex.core.boxing import *
@@ -42,5 +42,5 @@ def dpex_dispatcher_const(context):
4242
"call_kernel",
4343
"call_kernel_async",
4444
"IntEnumLiteral",
45-
"SPVKernelDispatcher",
45+
"SPIRVKernelDispatcher",
4646
]

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
from numba.core import cgutils, types
1212
from numba.extending import intrinsic, overload, overload_method
1313

14+
from numba_dpex._kernel_api_impl.spirv.target import (
15+
CC_SPIR_FUNC,
16+
LLVM_SPIRV_ARGS,
17+
)
1418
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
15-
from numba_dpex.core.targets.kernel_target import CC_SPIR_FUNC, LLVM_SPIRV_ARGS
1619
from numba_dpex.core.types import USMNdArray
1720
from numba_dpex.kernel_api import (
1821
AddressSpace,

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/spv_fn_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from llvmlite import ir as llvmir
1111
from numba.core import cgutils, types
1212

13+
from numba_dpex._kernel_api_impl.spirv.target import CC_SPIR_FUNC
1314
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
14-
from numba_dpex.core.targets.kernel_target import CC_SPIR_FUNC
1515

1616

1717
def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):

0 commit comments

Comments
 (0)