|
8 | 8 | from numba.extending import typeof_impl
|
9 | 9 | from numba.np import numpy_support
|
10 | 10 |
|
| 11 | +from numba_dpex.kernel_api import AtomicRef, Group, Item, LocalAccessor, NdItem |
11 | 12 | from numba_dpex.kernel_api.ranges import NdRange, Range
|
12 | 13 | from numba_dpex.utils.constants import address_space
|
13 | 14 |
|
14 | 15 | from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
|
15 | 16 | from ..types.dpnp_ndarray_type import DpnpNdArray
|
| 17 | +from ..types.kernel_api.atomic_ref import AtomicRefType |
| 18 | +from ..types.kernel_api.index_space_ids import GroupType, ItemType, NdItemType |
| 19 | +from ..types.kernel_api.local_accessor import LocalAccessorType |
16 | 20 | from ..types.kernel_api.ranges import NdRangeType, RangeType
|
17 | 21 | from ..types.usm_ndarray_type import USMNdArray
|
18 | 22 |
|
@@ -150,3 +154,83 @@ def typeof_ndrange(val, c):
|
150 | 154 | Returns: A numba_dpex.core.types.range_types.RangeType instance.
|
151 | 155 | """
|
152 | 156 | return NdRangeType(val.global_range.ndim)
|
| 157 | + |
| 158 | + |
| 159 | +@typeof_impl.register(AtomicRef) |
| 160 | +def typeof_atomic_ref(val: AtomicRef, ctx) -> AtomicRefType: |
| 161 | + """Returns a ``numba_dpex.experimental.dpctpp_types.AtomicRefType`` |
| 162 | + instance for a Python AtomicRef object. |
| 163 | +
|
| 164 | + Args: |
| 165 | + val (AtomicRef): Instance of the AtomicRef type. |
| 166 | + ctx : Numba typing context used for type inference. |
| 167 | +
|
| 168 | + Returns: AtomicRefType object corresponding to the AtomicRef object. |
| 169 | +
|
| 170 | + """ |
| 171 | + dtype = typeof_impl(val.ref, ctx) |
| 172 | + |
| 173 | + return AtomicRefType( |
| 174 | + dtype=dtype, |
| 175 | + memory_order=val.memory_order.value, |
| 176 | + memory_scope=val.memory_scope.value, |
| 177 | + address_space=val.address_space.value, |
| 178 | + ) |
| 179 | + |
| 180 | + |
| 181 | +@typeof_impl.register(Group) |
| 182 | +def typeof_group(val: Group, c): |
| 183 | + """Registers the type inference implementation function for a |
| 184 | + numba_dpex.kernel_api.Group PyObject. |
| 185 | +
|
| 186 | + Args: |
| 187 | + val : An instance of numba_dpex.kernel_api.Group. |
| 188 | + c : Unused argument used to be consistent with Numba API. |
| 189 | +
|
| 190 | + Returns: A numba_dpex.experimental.core.types.kernel_api.items.GroupType |
| 191 | + instance. |
| 192 | + """ |
| 193 | + return GroupType(val.ndim) |
| 194 | + |
| 195 | + |
| 196 | +@typeof_impl.register(Item) |
| 197 | +def typeof_item(val: Item, c): |
| 198 | + """Registers the type inference implementation function for a |
| 199 | + numba_dpex.kernel_api.Item PyObject. |
| 200 | +
|
| 201 | + Args: |
| 202 | + val : An instance of numba_dpex.kernel_api.Item. |
| 203 | + c : Unused argument used to be consistent with Numba API. |
| 204 | +
|
| 205 | + Returns: A numba_dpex.experimental.core.types.kernel_api.items.ItemType |
| 206 | + instance. |
| 207 | + """ |
| 208 | + return ItemType(val.dimensions) |
| 209 | + |
| 210 | + |
| 211 | +@typeof_impl.register(NdItem) |
| 212 | +def typeof_nditem(val: NdItem, c): |
| 213 | + """Registers the type inference implementation function for a |
| 214 | + numba_dpex.kernel_api.NdItem PyObject. |
| 215 | +
|
| 216 | + Args: |
| 217 | + val : An instance of numba_dpex.kernel_api.NdItem. |
| 218 | + c : Unused argument used to be consistent with Numba API. |
| 219 | +
|
| 220 | + Returns: A numba_dpex.experimental.core.types.kernel_api.items.NdItemType |
| 221 | + instance. |
| 222 | + """ |
| 223 | + return NdItemType(val.dimensions) |
| 224 | + |
| 225 | + |
| 226 | +@typeof_impl.register(LocalAccessor) |
| 227 | +def typeof_local_accessor(val: LocalAccessor, c) -> LocalAccessorType: |
| 228 | + """Returns a ``numba_dpex.experimental.dpctpp_types.LocalAccessorType`` |
| 229 | + instance for a Python LocalAccessor object. |
| 230 | + Args: |
| 231 | + val (LocalAccessor): Instance of the LocalAccessor type. |
| 232 | + c : Numba typing context used for type inference. |
| 233 | + Returns: LocalAccessorType object corresponding to the LocalAccessor object. |
| 234 | + """ |
| 235 | + # pylint: disable=protected-access |
| 236 | + return LocalAccessorType(ndim=len(val._shape), dtype=val._dtype) |
0 commit comments