Skip to content

Commit 5f23bc5

Browse files
author
Diptorup Deb
authored
Merge pull request #1225 from IntelPython/experimental/use_standalone_datamodel_manager
Register the experimental kernel target as a fully standalone Numba hardware target
2 parents ca8e769 + 529d33c commit 5f23bc5

File tree

5 files changed

+117
-48
lines changed

5 files changed

+117
-48
lines changed

numba_dpex/experimental/decorators.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,13 @@
88
import inspect
99

1010
from numba.core import sigutils
11-
from numba.core.target_extension import jit_registry, target_registry
11+
from numba.core.target_extension import (
12+
jit_registry,
13+
resolve_dispatcher_from_str,
14+
target_registry,
15+
)
1216

13-
from .kernel_dispatcher import KernelDispatcher
17+
from .target import DPEX_KERNEL_EXP_TARGET_NAME
1418

1519

1620
def kernel(func_or_sig=None, **options):
@@ -24,11 +28,14 @@ def kernel(func_or_sig=None, **options):
2428
* All array arguments passed to a kernel should adhere to compute
2529
follows data programming model.
2630
"""
31+
32+
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
33+
2734
# FIXME: The options need to be evaluated and checked here like it is
2835
# done in numba.core.decorators.jit
2936

3037
def _kernel_dispatcher(pyfunc):
31-
return KernelDispatcher(
38+
return dispatcher(
3239
pyfunc=pyfunc,
3340
targetoptions=options,
3441
)
@@ -59,9 +66,7 @@ def _kernel_dispatcher(pyfunc):
5966
func_or_sig = [func_or_sig]
6067

6168
def _specialized_kernel_dispatcher(pyfunc):
62-
return KernelDispatcher(
63-
pyfunc=pyfunc,
64-
)
69+
return dispatcher(pyfunc=pyfunc)
6570

6671
return _specialized_kernel_dispatcher
6772
func = func_or_sig
@@ -75,4 +80,4 @@ def _specialized_kernel_dispatcher(pyfunc):
7580
return _kernel_dispatcher(func)
7681

7782

78-
jit_registry[target_registry["dpex_kernel"]] = kernel
83+
jit_registry[target_registry[DPEX_KERNEL_EXP_TARGET_NAME]] = kernel

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 45 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
"""
88
from collections import namedtuple
99
from contextlib import ExitStack
10+
from typing import Tuple
1011

1112
import numba.core.event as ev
1213
from numba.core import errors, sigutils, types
@@ -24,13 +25,12 @@
2425
from numba_dpex.core.pipelines import kernel_compiler
2526
from numba_dpex.core.types import DpnpNdArray
2627

27-
from .target import dpex_exp_kernel_target
28+
from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target
2829

2930
_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])
3031

3132
_KernelCompileResult = namedtuple(
32-
"_KernelCompileResult",
33-
["status", "cres_or_error", "entry_point"],
33+
"_KernelCompileResult", CompileResult._fields + ("kernel_device_ir_module",)
3434
)
3535

3636

@@ -96,15 +96,15 @@ def _compile_to_spirv(
9696
)
9797

9898
def compile(self, args, return_type):
99-
kcres = self._compile_cached(args, return_type)
100-
if kcres.status:
99+
status, kcres = self._compile_cached(args, return_type)
100+
if status:
101101
return kcres
102102

103-
raise kcres.cres_or_error
103+
raise kcres
104104

105105
def _compile_cached(
106106
self, args, return_type: types.Type
107-
) -> _KernelCompileResult:
107+
) -> Tuple[bool, _KernelCompileResult]:
108108
"""Compiles the kernel function to bitcode and generates a host-callable
109109
wrapper to submit the kernel to a SYCL queue.
110110
@@ -137,34 +137,45 @@ def _compile_cached(
137137
"""
138138
key = tuple(args), return_type
139139
try:
140-
return _KernelCompileResult(False, self._failed_cache[key], None)
140+
return False, self._failed_cache[key]
141141
except KeyError:
142142
pass
143143

144144
try:
145-
kernel_cres: CompileResult = self._compile_core(args, return_type)
145+
cres: CompileResult = self._compile_core(args, return_type)
146146

147-
kernel_library = kernel_cres.library
148-
kernel_fndesc = kernel_cres.fndesc
149-
kernel_targetctx = kernel_cres.target_context
150-
151-
kernel_module = self._compile_to_spirv(
152-
kernel_library, kernel_fndesc, kernel_targetctx
147+
kernel_device_ir_module = self._compile_to_spirv(
148+
cres.library, cres.fndesc, cres.target_context
153149
)
154150

151+
kcres_attrs = []
152+
153+
for cres_field in cres._fields:
154+
cres_attr = getattr(cres, cres_field)
155+
if cres_field == "entry_point":
156+
if cres_attr is not None:
157+
raise AssertionError(
158+
"Compiled kernel and device_func should be "
159+
"compiled with compile_cfunc option turned off"
160+
)
161+
cres_attr = cres.fndesc.qualname
162+
kcres_attrs.append(cres_attr)
163+
164+
kcres_attrs.append(kernel_device_ir_module)
165+
155166
if config.DUMP_KERNEL_LLVM:
156167
with open(
157-
kernel_cres.fndesc.llvm_func_name + ".ll",
168+
cres.fndesc.llvm_func_name + ".ll",
158169
"w",
159170
encoding="UTF-8",
160171
) as f:
161-
f.write(kernel_cres.library.final_module)
172+
f.write(cres.library.final_module)
162173

163174
except errors.TypingError as e:
164175
self._failed_cache[key] = e
165-
return _KernelCompileResult(False, e, None)
176+
return False, e
166177

167-
return _KernelCompileResult(True, kernel_cres, kernel_module)
178+
return True, _KernelCompileResult(*kcres_attrs)
168179

169180

170181
class KernelDispatcher(Dispatcher):
@@ -234,7 +245,14 @@ def typeof_pyval(self, val):
234245

235246
def add_overload(self, cres):
236247
args = tuple(cres.signature.args)
237-
self.overloads[args] = cres.entry_point
248+
self.overloads[args] = cres
249+
250+
def get_overload_device_ir(self, sig):
251+
"""
252+
Return the compiled device bitcode for the given signature.
253+
"""
254+
args, _ = sigutils.normalize_signature(sig)
255+
return self.overloads[tuple(args)].kernel_device_ir_module
238256

239257
def compile(self, sig) -> _KernelCompileResult:
240258
disp = self._get_dispatcher_for_current_target()
@@ -274,7 +292,7 @@ def cb_llvm(dur):
274292
# Don't recompile if signature already exists
275293
existing = self.overloads.get(tuple(args))
276294
if existing is not None:
277-
return existing
295+
return existing.entry_point
278296

279297
# TODO: Enable caching
280298
# Add code to enable on disk caching of a binary spirv kernel.
@@ -298,7 +316,11 @@ def folded(args, kws):
298316
)[1]
299317

300318
raise e.bind_fold_arguments(folded)
301-
self.add_overload(kcres.cres_or_error)
319+
self.add_overload(kcres)
320+
321+
kcres.target_context.insert_user_function(
322+
kcres.entry_point, kcres.fndesc, [kcres.library]
323+
)
302324

303325
# TODO: enable caching of kernel_module
304326
# https://github.com/IntelPython/numba-dpex/issues/1197
@@ -318,5 +340,5 @@ def __call__(self, *args, **kw_args):
318340
raise NotImplementedError
319341

320342

321-
_dpex_target = target_registry["dpex_kernel"]
343+
_dpex_target = target_registry[DPEX_KERNEL_EXP_TARGET_NAME]
322344
dispatcher_registry[_dpex_target] = KernelDispatcher

numba_dpex/experimental/launcher.py

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111

1212
from llvmlite import ir as llvmir
1313
from numba.core import cgutils, cpu, types
14+
from numba.core.datamodel import default_manager as numba_default_dmm
1415
from numba.extending import intrinsic, overload
1516

1617
from numba_dpex import config, dpjit
@@ -192,23 +193,27 @@ def create_llvm_values_for_index_space(
192193
ndim = indexer_argty.ndim
193194
grange_extents = []
194195
lrange_extents = []
195-
datamodel = self._kernel_targetctx.data_model_manager.lookup(
196-
indexer_argty
197-
)
196+
indexer_datamodel = numba_default_dmm.lookup(indexer_argty)
198197

199198
if isinstance(indexer_argty, RangeType):
200199
for dim_num in range(ndim):
201-
dim_pos = datamodel.get_field_position("dim" + str(dim_num))
200+
dim_pos = indexer_datamodel.get_field_position(
201+
"dim" + str(dim_num)
202+
)
202203
grange_extents.append(
203204
self._builder.extract_value(index_space_arg, dim_pos)
204205
)
205206
elif isinstance(indexer_argty, NdRangeType):
206207
for dim_num in range(ndim):
207-
gdim_pos = datamodel.get_field_position("gdim" + str(dim_num))
208+
gdim_pos = indexer_datamodel.get_field_position(
209+
"gdim" + str(dim_num)
210+
)
208211
grange_extents.append(
209212
self._builder.extract_value(index_space_arg, gdim_pos)
210213
)
211-
ldim_pos = datamodel.get_field_position("ldim" + str(dim_num))
214+
ldim_pos = indexer_datamodel.get_field_position(
215+
"ldim" + str(dim_num)
216+
)
212217
lrange_extents.append(
213218
self._builder.extract_value(index_space_arg, ldim_pos)
214219
)
@@ -308,7 +313,10 @@ def intrin_launch_trampoline(
308313
sig = types.void(kernel_fn, index_space, kernel_args)
309314
# signature of the kernel_fn
310315
kernel_sig = types.void(*kernel_args_list)
311-
kmodule: _KernelModule = kernel_fn.dispatcher.compile(kernel_sig)
316+
kernel_fn.dispatcher.compile(kernel_sig)
317+
kernel_module: _KernelModule = kernel_fn.dispatcher.get_overload_device_ir(
318+
kernel_sig
319+
)
312320
kernel_targetctx = kernel_fn.dispatcher.targetctx
313321

314322
def codegen(cgctx, builder, sig, llargs):
@@ -324,7 +332,7 @@ def codegen(cgctx, builder, sig, llargs):
324332
)
325333

326334
kernel_bc_byte_str = fn_body_gen.insert_kernel_bitcode_as_byte_str(
327-
kmodule
335+
kernel_module
328336
)
329337

330338
populated_kernel_args = (
@@ -341,10 +349,10 @@ def codegen(cgctx, builder, sig, llargs):
341349
kbref = fn_body_gen.create_kernel_bundle_from_spirv(
342350
queue_ref=qref,
343351
kernel_bc=kernel_bc_byte_str,
344-
kernel_bc_size_in_bytes=len(kmodule.kernel_bitcode),
352+
kernel_bc_size_in_bytes=len(kernel_module.kernel_bitcode),
345353
)
346354

347-
kref = fn_body_gen.get_kernel(kmodule, kbref)
355+
kref = fn_body_gen.get_kernel(kernel_module, kbref)
348356

349357
index_space_values = fn_body_gen.create_llvm_values_for_index_space(
350358
indexer_argty=sig.args[1],

numba_dpex/experimental/models.py

Lines changed: 25 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,19 +2,38 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""Provides Numba datamodel for the numba_dpex types introduced in the
5+
"""Provides the Numba data models for the numba_dpex types introduced in the
66
numba_dpex.experimental module.
77
"""
88

9-
from numba.core.datamodel import models
9+
from numba.core.datamodel import DataModelManager, models
1010
from numba.core.extending import register_model
1111

12-
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dmm
12+
import numba_dpex.core.datamodel.models as dpex_core_models
1313

1414
from .types import KernelDispatcherType
1515

16-
# Register the types and datamodel in the DpexKernelTargetContext
17-
dmm.register(KernelDispatcherType, models.OpaqueModel)
1816

19-
# Register the types and datamodel in the DpexTargetContext
17+
def _init_exp_data_model_manager() -> DataModelManager:
18+
"""Initializes a DpexExpKernelTarget-specific data model manager.
19+
20+
Extends the DpexKernelTargetContext's datamodel manager with all
21+
experimental types that are getting added to the kernel API.
22+
23+
Returns:
24+
DataModelManager: A numba-dpex DpexExpKernelTarget-specific data model
25+
manager
26+
"""
27+
28+
dmm = dpex_core_models.dpex_data_model_manager.copy()
29+
30+
# Register the types and data model in the DpexExpTargetContext
31+
# Add here...
32+
33+
return dmm
34+
35+
36+
exp_dmm = _init_exp_data_model_manager()
37+
38+
# Register any new type that should go into numba.core.datamodel.default_manager
2039
register_model(KernelDispatcherType)(models.OpaqueModel)

numba_dpex/experimental/target.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,13 +9,24 @@
99
from functools import cached_property
1010

1111
from numba.core.descriptors import TargetDescriptor
12+
from numba.core.target_extension import GPU, target_registry
1213

1314
from numba_dpex.core.descriptor import DpexTargetOptions
1415
from numba_dpex.core.targets.kernel_target import (
15-
DPEX_KERNEL_TARGET_NAME,
1616
DpexKernelTargetContext,
1717
DpexKernelTypingContext,
1818
)
19+
from numba_dpex.experimental.models import exp_dmm
20+
21+
22+
# pylint: disable=R0903
23+
class SyclDeviceExp(GPU):
24+
"""Mark the hardware target as SYCL Device."""
25+
26+
27+
DPEX_KERNEL_EXP_TARGET_NAME = "dpex_kernel_exp"
28+
29+
target_registry[DPEX_KERNEL_EXP_TARGET_NAME] = SyclDeviceExp
1930

2031

2132
class DpexExpKernelTypingContext(DpexKernelTypingContext):
@@ -41,6 +52,10 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext):
4152
they are stable enough to be migrated to DpexKernelTargetContext.
4253
"""
4354

55+
def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
56+
super().__init__(typingctx, target)
57+
self.data_model_manager = exp_dmm
58+
4459

4560
class DpexExpKernelTarget(TargetDescriptor):
4661
"""
@@ -77,4 +92,4 @@ def typing_context(self):
7792

7893

7994
# A global instance of the DpexKernelTarget with the experimental features
80-
dpex_exp_kernel_target = DpexExpKernelTarget(DPEX_KERNEL_TARGET_NAME)
95+
dpex_exp_kernel_target = DpexExpKernelTarget(DPEX_KERNEL_EXP_TARGET_NAME)

0 commit comments

Comments
 (0)