Skip to content

Commit 53bb704

Browse files
authored
Merge pull request #1412 from IntelPython/merge_exp_spv_targt_into_spv_target
Completely remove numba_dpex.experimental module
2 parents 8e3b63d + 7729548 commit 53bb704

22 files changed

+228
-432
lines changed

numba_dpex/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from .kernel_api_impl.spirv import target as spirv_kernel_target
2020
from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc
21+
from .register_kernel_api_overloads import init_kernel_api_spirv_overloads
2122

2223

2324
def load_dpctl_sycl_interface():
@@ -136,11 +137,16 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
136137
__version__ = get_versions()["version"]
137138
del get_versions
138139

140+
# Initialize the kernel_api SPIRV overloads
141+
init_kernel_api_spirv_overloads()
142+
139143
__all__ = types.__all__ + [
140144
"call_kernel",
145+
"call_kernel_async",
141146
"device_func",
142147
"dpjit",
143148
"kernel",
149+
"prange",
144150
"Range",
145151
"NdRange",
146152
]

numba_dpex/core/decorators.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
)
1414

1515
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME
16-
from numba_dpex.experimental.target import DPEX_KERNEL_EXP_TARGET_NAME
1716
from numba_dpex.kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher
18-
from numba_dpex.kernel_api_impl.spirv.target import CompilationMode
17+
from numba_dpex.kernel_api_impl.spirv.target import (
18+
SPIRV_TARGET_NAME,
19+
CompilationMode,
20+
)
1921

2022

2123
def _parse_func_or_sig(signature_or_function):
@@ -154,7 +156,7 @@ def vecadd(item: kapi.Item, a, b, c):
154156

155157
# dispatcher is a type:
156158
# <class 'numba_dpex.experimental.kernel_dispatcher.KernelDispatcher'>
157-
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
159+
dispatcher = resolve_dispatcher_from_str(SPIRV_TARGET_NAME)
158160
if "_compilation_mode" in options:
159161
user_compilation_mode = options["_compilation_mode"]
160162
warn(
@@ -280,7 +282,7 @@ def another_kernel(nd_item: NdItem, a):
280282
281283
dpex_exp.call_kernel(another_kernel, dpex.NdRange((N,), (N,)), b)
282284
"""
283-
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
285+
dispatcher = resolve_dispatcher_from_str(SPIRV_TARGET_NAME)
284286

285287
if "_compilation_mode" in options:
286288
user_compilation_mode = options["_compilation_mode"]
@@ -342,4 +344,4 @@ def dpjit(*args, **kws):
342344
# add it to the decorator registry, this is so e.g. @overload can look up a
343345
# JIT function to do the compilation work.
344346
jit_registry[target_registry[DPEX_TARGET_NAME]] = dpjit
345-
jit_registry[target_registry[DPEX_KERNEL_EXP_TARGET_NAME]] = device_func
347+
jit_registry[target_registry[SPIRV_TARGET_NAME]] = device_func

numba_dpex/core/typing/typeof.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
from numba.extending import typeof_impl
99
from numba.np import numpy_support
1010

11+
from numba_dpex.kernel_api import AtomicRef, Group, Item, LocalAccessor, NdItem
1112
from numba_dpex.kernel_api.ranges import NdRange, Range
1213
from numba_dpex.utils.constants import address_space
1314

1415
from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
1516
from ..types.dpnp_ndarray_type import DpnpNdArray
17+
from ..types.kernel_api.atomic_ref import AtomicRefType
18+
from ..types.kernel_api.index_space_ids import GroupType, ItemType, NdItemType
19+
from ..types.kernel_api.local_accessor import LocalAccessorType
1620
from ..types.kernel_api.ranges import NdRangeType, RangeType
1721
from ..types.usm_ndarray_type import USMNdArray
1822

@@ -150,3 +154,83 @@ def typeof_ndrange(val, c):
150154
Returns: A numba_dpex.core.types.range_types.RangeType instance.
151155
"""
152156
return NdRangeType(val.global_range.ndim)
157+
158+
159+
@typeof_impl.register(AtomicRef)
160+
def typeof_atomic_ref(val: AtomicRef, ctx) -> AtomicRefType:
161+
"""Returns a ``numba_dpex.experimental.dpctpp_types.AtomicRefType``
162+
instance for a Python AtomicRef object.
163+
164+
Args:
165+
val (AtomicRef): Instance of the AtomicRef type.
166+
ctx : Numba typing context used for type inference.
167+
168+
Returns: AtomicRefType object corresponding to the AtomicRef object.
169+
170+
"""
171+
dtype = typeof_impl(val.ref, ctx)
172+
173+
return AtomicRefType(
174+
dtype=dtype,
175+
memory_order=val.memory_order.value,
176+
memory_scope=val.memory_scope.value,
177+
address_space=val.address_space.value,
178+
)
179+
180+
181+
@typeof_impl.register(Group)
182+
def typeof_group(val: Group, c):
183+
"""Registers the type inference implementation function for a
184+
numba_dpex.kernel_api.Group PyObject.
185+
186+
Args:
187+
val : An instance of numba_dpex.kernel_api.Group.
188+
c : Unused argument used to be consistent with Numba API.
189+
190+
Returns: A numba_dpex.experimental.core.types.kernel_api.items.GroupType
191+
instance.
192+
"""
193+
return GroupType(val.ndim)
194+
195+
196+
@typeof_impl.register(Item)
197+
def typeof_item(val: Item, c):
198+
"""Registers the type inference implementation function for a
199+
numba_dpex.kernel_api.Item PyObject.
200+
201+
Args:
202+
val : An instance of numba_dpex.kernel_api.Item.
203+
c : Unused argument used to be consistent with Numba API.
204+
205+
Returns: A numba_dpex.experimental.core.types.kernel_api.items.ItemType
206+
instance.
207+
"""
208+
return ItemType(val.dimensions)
209+
210+
211+
@typeof_impl.register(NdItem)
212+
def typeof_nditem(val: NdItem, c):
213+
"""Registers the type inference implementation function for a
214+
numba_dpex.kernel_api.NdItem PyObject.
215+
216+
Args:
217+
val : An instance of numba_dpex.kernel_api.NdItem.
218+
c : Unused argument used to be consistent with Numba API.
219+
220+
Returns: A numba_dpex.experimental.core.types.kernel_api.items.NdItemType
221+
instance.
222+
"""
223+
return NdItemType(val.dimensions)
224+
225+
226+
@typeof_impl.register(LocalAccessor)
227+
def typeof_local_accessor(val: LocalAccessor, c) -> LocalAccessorType:
228+
"""Returns a ``numba_dpex.experimental.dpctpp_types.LocalAccessorType``
229+
instance for a Python LocalAccessor object.
230+
Args:
231+
val (LocalAccessor): Instance of the LocalAccessor type.
232+
c : Numba typing context used for type inference.
233+
Returns: LocalAccessorType object corresponding to the LocalAccessor object.
234+
"""
235+
# pylint: disable=protected-access
236+
return LocalAccessorType(ndim=len(val._shape), dtype=val._dtype)

numba_dpex/experimental/__init__.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

numba_dpex/experimental/models.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

numba_dpex/experimental/target.py

Lines changed: 0 additions & 97 deletions
This file was deleted.

numba_dpex/experimental/testing.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

0 commit comments

Comments
 (0)