Skip to content

Commit 62b942d

Browse files
author
Diptorup Deb
committed
Add a new device and target string for experimental target.
1 parent ca8e769 commit 62b942d

File tree

3 files changed

+26
-11
lines changed

3 files changed

+26
-11
lines changed

numba_dpex/experimental/decorators.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@
88
import inspect
99

1010
from numba.core import sigutils
11-
from numba.core.target_extension import jit_registry, target_registry
11+
from numba.core.target_extension import (
12+
jit_registry,
13+
resolve_dispatcher_from_str,
14+
target_registry,
15+
)
1216

13-
from .kernel_dispatcher import KernelDispatcher
17+
from .target import DPEX_KERNEL_EXP_TARGET_NAME
1418

1519

1620
def kernel(func_or_sig=None, **options):
@@ -24,11 +28,14 @@ def kernel(func_or_sig=None, **options):
2428
* All array arguments passed to a kernel should adhere to compute
2529
follows data programming model.
2630
"""
31+
32+
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
33+
2734
# FIXME: The options need to be evaluated and checked here like it is
2835
# done in numba.core.decorators.jit
2936

3037
def _kernel_dispatcher(pyfunc):
31-
return KernelDispatcher(
38+
return dispatcher(
3239
pyfunc=pyfunc,
3340
targetoptions=options,
3441
)
@@ -59,9 +66,7 @@ def _kernel_dispatcher(pyfunc):
5966
func_or_sig = [func_or_sig]
6067

6168
def _specialized_kernel_dispatcher(pyfunc):
62-
return KernelDispatcher(
63-
pyfunc=pyfunc,
64-
)
69+
return dispatcher(pyfunc=pyfunc)
6570

6671
return _specialized_kernel_dispatcher
6772
func = func_or_sig
@@ -75,4 +80,4 @@ def _specialized_kernel_dispatcher(pyfunc):
7580
return _kernel_dispatcher(func)
7681

7782

78-
jit_registry[target_registry["dpex_kernel"]] = kernel
83+
jit_registry[target_registry[DPEX_KERNEL_EXP_TARGET_NAME]] = kernel

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from numba_dpex.core.pipelines import kernel_compiler
2525
from numba_dpex.core.types import DpnpNdArray
2626

27-
from .target import dpex_exp_kernel_target
27+
from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target
2828

2929
_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])
3030

@@ -318,5 +318,5 @@ def __call__(self, *args, **kw_args):
318318
raise NotImplementedError
319319

320320

321-
_dpex_target = target_registry["dpex_kernel"]
321+
_dpex_target = target_registry[DPEX_KERNEL_EXP_TARGET_NAME]
322322
dispatcher_registry[_dpex_target] = KernelDispatcher

numba_dpex/experimental/target.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,25 @@
99
from functools import cached_property
1010

1111
from numba.core.descriptors import TargetDescriptor
12+
from numba.core.target_extension import GPU, target_registry
1213

1314
from numba_dpex.core.descriptor import DpexTargetOptions
1415
from numba_dpex.core.targets.kernel_target import (
15-
DPEX_KERNEL_TARGET_NAME,
1616
DpexKernelTargetContext,
1717
DpexKernelTypingContext,
1818
)
1919

2020

21+
# pylint: disable=R0903
22+
class SyclDeviceExp(GPU):
23+
"""Mark the hardware target as SYCL Device."""
24+
25+
26+
DPEX_KERNEL_EXP_TARGET_NAME = "dpex_kernel_exp"
27+
28+
target_registry[DPEX_KERNEL_EXP_TARGET_NAME] = SyclDeviceExp
29+
30+
2131
class DpexExpKernelTypingContext(DpexKernelTypingContext):
2232
"""Experimental typing context class extending the DpexKernelTypingContext
2333
by overriding super class functions for new experimental types.
@@ -77,4 +87,4 @@ def typing_context(self):
7787

7888

7989
# A global instance of the DpexKernelTarget with the experimental features
80-
dpex_exp_kernel_target = DpexExpKernelTarget(DPEX_KERNEL_TARGET_NAME)
90+
dpex_exp_kernel_target = DpexExpKernelTarget(DPEX_KERNEL_EXP_TARGET_NAME)

0 commit comments

Comments
 (0)