Skip to content

Commit b3ae9c2

Browse files
committed
Add item and nd_item style of kernel programing
1 parent cedbfe9 commit b3ae9c2

File tree

13 files changed

+447
-9
lines changed

13 files changed

+447
-9
lines changed

numba_dpex/experimental/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
# Temporary so that Range and NdRange work in experimental call_kernel
1212
from numba_dpex.core.boxing import *
1313

14-
from ._kernel_dpcpp_spirv_overloads import _atomic_ref_overloads
14+
from ._kernel_dpcpp_spirv_overloads import (
15+
_atomic_ref_overloads,
16+
_index_space_id_overloads,
17+
)
1518
from .decorators import device_func, kernel
1619
from .kernel_dispatcher import KernelDispatcher
1720
from .launcher import call_kernel, call_kernel_async
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Implements the SPIR-V overloads for the kernel_api.items class methods.
7+
"""
8+
9+
from numba.core import types
10+
from numba.core.errors import TypingError
11+
from numba.extending import intrinsic, overload_method
12+
13+
from numba_dpex.experimental.core.types.kernel_api.items import (
14+
ItemType,
15+
NdItemType,
16+
)
17+
from numba_dpex.ocl._declare_function import _declare_function
18+
19+
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
20+
21+
22+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
23+
def _intrinsic_get_global_id(
24+
ty_context, ty_dim # pylint: disable=unused-argument
25+
):
26+
"""Generates instruction for spirv i64 get_global_id(i32) call."""
27+
sig = types.int64(types.int32)
28+
29+
def _intrinsic_exchange_gen(context, builder, sig, args):
30+
[dim] = args
31+
get_global_id = _declare_function(
32+
# unsigned int - is what demangler returns from IR instruction
33+
# generated by dpcpp. However int32 is passed as an argument.
34+
# Most likely it does not matter, since argument can be from 0 to 3.
35+
# TODO: https://github.com/IntelPython/numba-dpex/issues/936
36+
# Numba generates unnecessary checks because of the type mismatch.
37+
context,
38+
builder,
39+
"get_global_id",
40+
sig,
41+
["unsigned int"],
42+
)
43+
res = builder.call(get_global_id, [dim])
44+
return context.cast(builder, res, types.uintp, types.intp)
45+
46+
return sig, _intrinsic_exchange_gen
47+
48+
49+
@overload_method(ItemType, "get_id", target=DPEX_KERNEL_EXP_TARGET_NAME)
50+
def ol_item_get_id(item, dim):
51+
"""SPIR-V overload for
52+
:meth:`numba_dpex.kernel_api.Item.get_id`.
53+
54+
Generates the same LLVM IR instruction as dpcpp for the
55+
`sycl::item::get_id` function.
56+
57+
Raises:
58+
TypingError: When argument is not an integer.
59+
"""
60+
if not isinstance(item, ItemType):
61+
raise TypingError("Only Item is supported")
62+
63+
if not isinstance(dim, types.Integer):
64+
raise TypingError("Only integers supported")
65+
66+
# pylint: disable=unused-argument
67+
def ol_get_id(item, dim):
68+
# pylint: disable=no-value-for-parameter
69+
return _intrinsic_get_global_id(dim)
70+
71+
return ol_get_id
72+
73+
74+
@overload_method(
75+
NdItemType, "get_global_id", target=DPEX_KERNEL_EXP_TARGET_NAME
76+
)
77+
def ol_nd_item_get_global_id(nd_item, dim):
78+
"""SPIR-V overload for
79+
:meth:`numba_dpex.kernel_api.NdItem.get_global_id`.
80+
81+
Generates the same LLVM IR instruction as dpcpp for the
82+
`sycl::nd_item::get_global_id` function.
83+
84+
Raises:
85+
TypingError: When argument is not an integer.
86+
"""
87+
if not isinstance(nd_item, NdItemType):
88+
raise TypingError("Only NdItem is supported")
89+
90+
if not isinstance(dim, types.Integer):
91+
raise TypingError("Only integers supported")
92+
93+
# pylint: disable=unused-argument
94+
def ol_get_global_id(nd_item, dim):
95+
# pylint: disable=no-value-for-parameter
96+
return _intrinsic_get_global_id(dim)
97+
98+
return ol_get_global_id
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Defines numba types for Item and NdItem classes"""
6+
7+
from numba.core import errors, types
8+
9+
10+
class ItemType(types.Type):
11+
"""Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.Item`"""
12+
13+
def __init__(self, ndim: int):
14+
self._ndim = ndim
15+
if ndim < 1 or ndim > 3:
16+
raise errors.TypingError(
17+
"ItemType can only have 1, 2 or 3 dimensions"
18+
)
19+
super().__init__(name="Item<" + str(ndim) + ">")
20+
21+
@property
22+
def ndim(self):
23+
"""Returns number of dimensions"""
24+
return self._ndim
25+
26+
@property
27+
def key(self):
28+
"""Numba type specific overload"""
29+
return self._ndim
30+
31+
def cast_python_value(self, args):
32+
raise NotImplementedError
33+
34+
35+
class NdItemType(types.Type):
36+
"""Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.NdItem`"""
37+
38+
def __init__(self, ndim: int):
39+
self._ndim = ndim
40+
if ndim < 1 or ndim > 3:
41+
raise errors.TypingError(
42+
"ItemType can only have 1, 2 or 3 dimensions"
43+
)
44+
super().__init__(name="NdItem<" + str(ndim) + ">")
45+
46+
@property
47+
def ndim(self):
48+
"""Returns number of dimensions"""
49+
return self._ndim
50+
51+
@property
52+
def key(self):
53+
"""Numba type specific overload"""
54+
return self._ndim
55+
56+
def cast_python_value(self, args):
57+
raise NotImplementedError

numba_dpex/experimental/launcher.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from either CPython or a numba_dpex.dpjit decorated function.
77
"""
88

9+
import warnings
10+
from inspect import signature
911
from typing import NamedTuple, Union
1012

1113
import dpctl
@@ -23,6 +25,10 @@
2325
from numba_dpex.core.utils import kernel_launcher as kl
2426
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
2527
from numba_dpex.dpctl_iface.wrappers import wrap_event_reference
28+
from numba_dpex.experimental.core.types.kernel_api.items import (
29+
ItemType,
30+
NdItemType,
31+
)
2632
from numba_dpex.experimental.kernel_dispatcher import (
2733
KernelDispatcher,
2834
_KernelCompileResult,
@@ -129,13 +135,30 @@ def _submit_kernel( # pylint: disable=too-many-arguments
129135
else:
130136
sig = ty_return(ty_kernel_fn, ty_index_space, ty_kernel_args_tuple)
131137

132-
kernel_sig = types.void(*ty_kernel_args_tuple)
138+
# Add Item/NdItem as a first argument to kernel arguments list. It is
139+
# an empty struct so any other modifications at kernel submission are not
140+
# needed.
141+
if len(signature(ty_kernel_fn.dispatcher.py_func).parameters) > len(
142+
ty_kernel_args_tuple
143+
):
144+
if isinstance(ty_index_space, RangeType):
145+
ty_item = ItemType(ty_index_space.ndim)
146+
else:
147+
ty_item = NdItemType(ty_index_space.ndim)
148+
149+
ty_kernel_args_tuple = (ty_item, *ty_kernel_args_tuple)
150+
else:
151+
warnings.warn(
152+
"Kernels without item/nd_item will be not supported in the future",
153+
DeprecationWarning,
154+
)
155+
133156
# ty_kernel_fn is type specific to exact function, so we can get function
134157
# directly from type and compile it. Thats why we don't need to get it in
135158
# codegen
136159
kernel_dispatcher: KernelDispatcher = ty_kernel_fn.dispatcher
137160
kcres: _KernelCompileResult = kernel_dispatcher.get_compile_result(
138-
kernel_sig
161+
types.void(*ty_kernel_args_tuple) # kernel signature
139162
)
140163
kernel_module: kl.SPIRVKernelModule = kcres.kernel_device_ir_module
141164
kernel_targetctx: DpexKernelTargetContext = kernel_dispatcher.targetctx

numba_dpex/experimental/models.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
from numba.core.extending import register_model
1414

1515
import numba_dpex.core.datamodel.models as dpex_core_models
16+
from numba_dpex.experimental.core.types.kernel_api.items import (
17+
ItemType,
18+
NdItemType,
19+
)
1620

1721
from .dpcpp_types import AtomicRefType
1822
from .literal_intenum_type import IntEnumLiteral
@@ -43,6 +47,15 @@ def __init__(self, dmm, fe_type):
4347
super().__init__(dmm, fe_type, be_type)
4448

4549

50+
class EmptyStructModel(StructModel):
51+
"""Data model that does not take space. Intended to be used with types that
52+
are presented only at typing stage and not represented physically."""
53+
54+
def __init__(self, dmm, fe_type):
55+
members = []
56+
super().__init__(dmm, fe_type, members)
57+
58+
4659
def _init_exp_data_model_manager() -> DataModelManager:
4760
"""Initializes a DpexExpKernelTarget-specific data model manager.
4861
@@ -60,10 +73,22 @@ def _init_exp_data_model_manager() -> DataModelManager:
6073
dmm.register(IntEnumLiteral, IntEnumLiteralModel)
6174
dmm.register(AtomicRefType, AtomicRefModel)
6275

76+
# Register the ItemType type
77+
dmm.register(ItemType, EmptyStructModel)
78+
79+
# Register the NdItemType type
80+
dmm.register(NdItemType, EmptyStructModel)
81+
6382
return dmm
6483

6584

6685
exp_dmm = _init_exp_data_model_manager()
6786

6887
# Register any new type that should go into numba.core.datamodel.default_manager
6988
register_model(KernelDispatcherType)(models.OpaqueModel)
89+
90+
# Register the ItemType type
91+
register_model(ItemType)(EmptyStructModel)
92+
93+
# Register the NdItemType type
94+
register_model(NdItemType)(EmptyStructModel)

numba_dpex/experimental/typeof.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99

1010
from numba.extending import typeof_impl
1111

12-
from numba_dpex.kernel_api import AtomicRef
12+
from numba_dpex.experimental.core.types.kernel_api.items import (
13+
ItemType,
14+
NdItemType,
15+
)
16+
from numba_dpex.kernel_api import AtomicRef, Item, NdItem
1317

1418
from .dpcpp_types import AtomicRefType
1519

@@ -34,3 +38,33 @@ def typeof_atomic_ref(val: AtomicRef, ctx) -> AtomicRefType:
3438
memory_scope=val.memory_scope.value,
3539
address_space=val.address_space.value,
3640
)
41+
42+
43+
@typeof_impl.register(Item)
44+
def typeof_item(val: Item, c):
45+
"""Registers the type inference implementation function for a
46+
numba_dpex.kernel_api.Item PyObject.
47+
48+
Args:
49+
val : An instance of numba_dpex.kernel_api.Item.
50+
c : Unused argument used to be consistent with Numba API.
51+
52+
Returns: A numba_dpex.experimental.core.types.kernel_api.items.ItemType
53+
instance.
54+
"""
55+
return ItemType(val.ndim)
56+
57+
58+
@typeof_impl.register(NdItem)
59+
def typeof_nditem(val, c):
60+
"""Registers the type inference implementation function for a
61+
numba_dpex.kernel_api.NdItem PyObject.
62+
63+
Args:
64+
val : An instance of numba_dpex.kernel_api.NdItem.
65+
c : Unused argument used to be consistent with Numba API.
66+
67+
Returns: A numba_dpex.experimental.core.types.kernel_api.items.NdItemType
68+
instance.
69+
"""
70+
return NdItemType(val.ndim)

numba_dpex/kernel_api/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
from .atomic_ref import AtomicRef
13+
from .index_space_ids import Item, NdItem
1314
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
1415
from .ranges import NdRange, Range
1516

@@ -20,4 +21,6 @@
2021
"MemoryScope",
2122
"NdRange",
2223
"Range",
24+
"NdItem",
25+
"Item",
2326
]

0 commit comments

Comments
 (0)