Skip to content

Commit ccb7606

Browse files
author
Diptorup Deb
authored
Merge pull request #1310 from IntelPython/feature/nd_item_get_group
Add Group and overload NdItem.get_group()
2 parents 5374eb4 + dc6c1b0 commit ccb7606

File tree

9 files changed

+163
-30
lines changed

9 files changed

+163
-30
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_group_barrier_overloads.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010

1111
from llvmlite import ir as llvmir
1212
from numba.core import cgutils, types
13+
from numba.core.errors import TypingError
1314
from numba.extending import intrinsic, overload
1415

1516
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
17+
from numba_dpex.experimental.core.types.kernel_api.items import GroupType
1618
from numba_dpex.experimental.target import DPEX_KERNEL_EXP_TARGET_NAME
1719
from numba_dpex.kernel_api import group_barrier
1820
from numba_dpex.kernel_api.memory_enums import MemoryOrder, MemoryScope
@@ -26,7 +28,7 @@
2628
except ValueError:
2729
warnings.warn(
2830
"convergent attribute is supported only starting llvmlite "
29-
+ "0.42. Not setting this attribute may result into unexpected behavior"
31+
+ "0.42. Not setting this attribute may result in unexpected behavior"
3032
+ "when using group_barrier"
3133
)
3234
_SUPPORT_CONVERGENT = False
@@ -76,7 +78,6 @@ def _intrinsic_barrier_codegen(
7678
llvmir.IntType(32),
7779
]
7880

79-
# TODO: split the function declaration from call
8081
fn = cgutils.get_or_insert_function(
8182
builder.module,
8283
llvmir.FunctionType(llvmir.VoidType(), spirv_fn_arg_types),
@@ -105,34 +106,39 @@ def _intrinsic_barrier_codegen(
105106
prefer_literal=True,
106107
target=DPEX_KERNEL_EXP_TARGET_NAME,
107108
)
108-
def ol_group_barrier(fence_scope=MemoryScope.WORK_GROUP):
109+
def ol_group_barrier(group, fence_scope=MemoryScope.WORK_GROUP):
109110
"""SPIR-V overload for
110111
:meth:`numba_dpex.kernel_api.group_barrier`.
111112
112-
Generates the same LLVM IR instruction as dpcpp for the
113+
Generates the same LLVM IR instruction as DPC++ for the SYCL
113114
`group_barrier` function.
114115
115116
Per SYCL spec, group_barrier must perform both control barrier and memory
116-
fence operations. Hence, group_barrier requires two scopes and memory
117-
consistency specification as three arguments.
117+
fence operations. Hence, group_barrier requires two scopes and one memory
118+
consistency specification as its three arguments.
118119
119120
mem_scope - scope of any memory consistency operations that are performed by
120121
the barrier. By default, mem_scope is set to `work_group`.
121122
exec_scope - scope that determines the set of work-items that synchronize at
122123
barrier. Set to `work_group` for group_barrier always.
123-
spirv_memory_semantics_mask - Based on sycl implementation.
124-
125-
Mask that is set to use sequential consistency memory order semantics
126-
always.
124+
spirv_memory_semantics_mask - Based on SYCL implementation. Always set to
125+
use sequential consistency memory order.
127126
"""
128127

128+
if not isinstance(group, GroupType):
129+
raise TypingError(
130+
"Expected a group should to be a GroupType value, but "
131+
f"encountered {type(group)}"
132+
)
133+
129134
mem_scope = _get_memory_scope(fence_scope)
130135
exec_scope = get_scope(MemoryScope.WORK_GROUP.value)
131136
spirv_memory_semantics_mask = get_memory_semantics_mask(
132137
MemoryOrder.SEQ_CST.value
133138
)
134139

135140
def _ol_group_barrier_impl(
141+
group,
136142
fence_scope=MemoryScope.WORK_GROUP,
137143
): # pylint: disable=unused-argument
138144
# pylint: disable=no-value-for-parameter

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
Implements the SPIR-V overloads for the kernel_api.items class methods.
77
"""
88

9-
from numba.core import types
9+
from numba.core import cgutils, types
1010
from numba.core.errors import TypingError
1111
from numba.extending import intrinsic, overload_method
1212

1313
from numba_dpex.experimental.core.types.kernel_api.items import (
14+
GroupType,
1415
ItemType,
1516
NdItemType,
1617
)
@@ -26,7 +27,7 @@ def _intrinsic_get_global_id(
2627
"""Generates instruction for spirv i64 get_global_id(i32) call."""
2728
sig = types.int64(types.int32)
2829

29-
def _intrinsic_exchange_gen(context, builder, sig, args):
30+
def _intrinsic_get_global_id_gen(context, builder, sig, args):
3031
[dim] = args
3132
get_global_id = _declare_function(
3233
# unsigned int - is what demangler returns from IR instruction
@@ -43,13 +44,35 @@ def _intrinsic_exchange_gen(context, builder, sig, args):
4344
res = builder.call(get_global_id, [dim])
4445
return context.cast(builder, res, types.uintp, types.intp)
4546

46-
return sig, _intrinsic_exchange_gen
47+
return sig, _intrinsic_get_global_id_gen
48+
49+
50+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
51+
def _intrinsic_get_group(
52+
ty_context, ty_nd_item: NdItemType # pylint: disable=unused-argument
53+
):
54+
"""Generates group with a dimension of nd_item."""
55+
56+
if not isinstance(ty_nd_item, NdItemType):
57+
raise TypingError(
58+
f"Expected an NdItemType value, but encountered {ty_nd_item}"
59+
)
60+
61+
ty_group = GroupType(ty_nd_item.ndim)
62+
sig = ty_group(ty_nd_item)
63+
64+
# pylint: disable=unused-argument
65+
def _intrinsic_get_group_gen(context, builder, sig, args):
66+
group_struct = cgutils.create_struct_proxy(ty_group)(context, builder)
67+
# pylint: disable=protected-access
68+
return group_struct._getvalue()
69+
70+
return sig, _intrinsic_get_group_gen
4771

4872

4973
@overload_method(ItemType, "get_id", target=DPEX_KERNEL_EXP_TARGET_NAME)
5074
def ol_item_get_id(item, dim):
51-
"""SPIR-V overload for
52-
:meth:`numba_dpex.kernel_api.Item.get_id`.
75+
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.Item.get_id`.
5376
5477
Generates the same LLVM IR instruction as dpcpp for the
5578
`sycl::item::get_id` function.
@@ -58,25 +81,30 @@ def ol_item_get_id(item, dim):
5881
TypingError: When argument is not an integer.
5982
"""
6083
if not isinstance(item, ItemType):
61-
raise TypingError("Only Item is supported")
84+
raise TypingError(
85+
"Expected an item should to be an Item value, but "
86+
f"encountered {type(item)}"
87+
)
6288

6389
if not isinstance(dim, types.Integer):
64-
raise TypingError("Only integers supported")
90+
raise TypingError(
91+
"Expected an Item's dim should to be an Integer value, but "
92+
f"encountered {type(dim)}"
93+
)
6594

6695
# pylint: disable=unused-argument
67-
def ol_get_id(item, dim):
96+
def ol_item_get_id_impl(item, dim):
6897
# pylint: disable=no-value-for-parameter
6998
return _intrinsic_get_global_id(dim)
7099

71-
return ol_get_id
100+
return ol_item_get_id_impl
72101

73102

74103
@overload_method(
75104
NdItemType, "get_global_id", target=DPEX_KERNEL_EXP_TARGET_NAME
76105
)
77106
def ol_nd_item_get_global_id(nd_item, dim):
78-
"""SPIR-V overload for
79-
:meth:`numba_dpex.kernel_api.NdItem.get_global_id`.
107+
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_global_id`.
80108
81109
Generates the same LLVM IR instruction as dpcpp for the
82110
`sycl::nd_item::get_global_id` function.
@@ -85,14 +113,46 @@ def ol_nd_item_get_global_id(nd_item, dim):
85113
TypingError: When argument is not an integer.
86114
"""
87115
if not isinstance(nd_item, NdItemType):
88-
raise TypingError("Only NdItem is supported")
116+
# since it is a method overload, this error should not be reached
117+
raise TypingError(
118+
"Expected a nd_item should to be a NdItem value, but "
119+
f"encountered {type(nd_item)}"
120+
)
89121

90122
if not isinstance(dim, types.Integer):
91-
raise TypingError("Only integers supported")
123+
raise TypingError(
124+
"Expected a NdItem's dim should to be an Integer value, but "
125+
f"encountered {type(dim)}"
126+
)
92127

93128
# pylint: disable=unused-argument
94-
def ol_get_global_id(nd_item, dim):
129+
def ol_nd_item_get_global_id_impl(nd_item, dim):
95130
# pylint: disable=no-value-for-parameter
96131
return _intrinsic_get_global_id(dim)
97132

98-
return ol_get_global_id
133+
return ol_nd_item_get_global_id_impl
134+
135+
136+
@overload_method(NdItemType, "get_group", target=DPEX_KERNEL_EXP_TARGET_NAME)
137+
def ol_nd_item_get_group(nd_item):
138+
"""SPIR-V overload for :meth:`numba_dpex.kernel_api.NdItem.get_group`.
139+
140+
Generates the same LLVM IR instruction as dpcpp for the
141+
`sycl::nd_item::get_group` function.
142+
143+
Raises:
144+
TypingError: When argument is not NdItem.
145+
"""
146+
if not isinstance(nd_item, NdItemType):
147+
# since it is a method overload, this error should not be reached
148+
raise TypingError(
149+
"Expected a nd_item should to be a NdItem value, but "
150+
f"encountered {type(nd_item)}"
151+
)
152+
153+
# pylint: disable=unused-argument
154+
def ol_nd_item_get_group_impl(nd_item):
155+
# pylint: disable=no-value-for-parameter
156+
return _intrinsic_get_group(nd_item)
157+
158+
return ol_nd_item_get_group_impl

numba_dpex/experimental/core/types/kernel_api/items.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,31 @@
77
from numba.core import errors, types
88

99

10+
class GroupType(types.Type):
11+
"""Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.Group`"""
12+
13+
def __init__(self, ndim: int):
14+
self._ndim = ndim
15+
if ndim < 1 or ndim > 3:
16+
raise errors.TypingError(
17+
"ItemType can only have 1, 2 or 3 dimensions"
18+
)
19+
super().__init__(name="Group<" + str(ndim) + ">")
20+
21+
@property
22+
def ndim(self):
23+
"""Returns number of dimensions"""
24+
return self._ndim
25+
26+
@property
27+
def key(self):
28+
"""Numba type specific overload"""
29+
return self._ndim
30+
31+
def cast_python_value(self, args):
32+
raise NotImplementedError
33+
34+
1035
class ItemType(types.Type):
1136
"""Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.Item`"""
1237

numba_dpex/experimental/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numba_dpex.core.datamodel.models as dpex_core_models
1616
from numba_dpex.experimental.core.types.kernel_api.items import (
17+
GroupType,
1718
ItemType,
1819
NdItemType,
1920
)
@@ -73,6 +74,9 @@ def _init_exp_data_model_manager() -> DataModelManager:
7374
dmm.register(IntEnumLiteral, IntEnumLiteralModel)
7475
dmm.register(AtomicRefType, AtomicRefModel)
7576

77+
# Register the GroupType type
78+
dmm.register(GroupType, EmptyStructModel)
79+
7680
# Register the ItemType type
7781
dmm.register(ItemType, EmptyStructModel)
7882

@@ -87,6 +91,9 @@ def _init_exp_data_model_manager() -> DataModelManager:
8791
# Register any new type that should go into numba.core.datamodel.default_manager
8892
register_model(KernelDispatcherType)(models.OpaqueModel)
8993

94+
# Register the GroupType type
95+
register_model(GroupType)(EmptyStructModel)
96+
9097
# Register the ItemType type
9198
register_model(ItemType)(EmptyStructModel)
9299

numba_dpex/experimental/typeof.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
from numba.extending import typeof_impl
1111

1212
from numba_dpex.experimental.core.types.kernel_api.items import (
13+
GroupType,
1314
ItemType,
1415
NdItemType,
1516
)
16-
from numba_dpex.kernel_api import AtomicRef, Item, NdItem
17+
from numba_dpex.kernel_api import AtomicRef, Group, Item, NdItem
1718

1819
from .dpcpp_types import AtomicRefType
1920

@@ -40,6 +41,21 @@ def typeof_atomic_ref(val: AtomicRef, ctx) -> AtomicRefType:
4041
)
4142

4243

44+
@typeof_impl.register(Group)
45+
def typeof_group(val: Group, c):
46+
"""Registers the type inference implementation function for a
47+
numba_dpex.kernel_api.Group PyObject.
48+
49+
Args:
50+
val : An instance of numba_dpex.kernel_api.Group.
51+
c : Unused argument used to be consistent with Numba API.
52+
53+
Returns: A numba_dpex.experimental.core.types.kernel_api.items.GroupType
54+
instance.
55+
"""
56+
return GroupType(val.ndim)
57+
58+
4359
@typeof_impl.register(Item)
4460
def typeof_item(val: Item, c):
4561
"""Registers the type inference implementation function for a

numba_dpex/kernel_api/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from .atomic_ref import AtomicRef
1313
from .barrier import group_barrier
14-
from .index_space_ids import Item, NdItem
14+
from .index_space_ids import Group, Item, NdItem
1515
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
1616
from .ranges import NdRange, Range
1717

@@ -22,6 +22,7 @@
2222
"MemoryScope",
2323
"NdRange",
2424
"Range",
25+
"Group",
2526
"NdItem",
2627
"Item",
2728
"group_barrier",

numba_dpex/kernel_api/barrier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
"""Python functions that simulate SYCL's barrier primitives.
66
"""
77

8+
from .index_space_ids import Group
89
from .memory_enums import MemoryScope
910

1011

11-
def group_barrier(fence_scope=MemoryScope.WORK_GROUP):
12+
def group_barrier(group: Group, fence_scope=MemoryScope.WORK_GROUP):
1213
"""Performs a barrier operation across all work-items in a work group.
1314
1415
The function is modeled after the ``sycl::group_barrier`` function. It

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,23 @@
1010
from .ranges import Range
1111

1212

13+
# pylint: disable=too-few-public-methods
14+
class Group:
15+
"""Analogue to the ``sycl::group`` type."""
16+
17+
def __init__(
18+
self,
19+
global_range: Range,
20+
local_range: Range,
21+
group_range: Range,
22+
index: list,
23+
):
24+
self._global_range = global_range
25+
self._local_range = local_range
26+
self._group_range = group_range
27+
self._index = index
28+
29+
1330
class Item:
1431
"""Analogue to the ``sycl::item`` type. Identifies an instance of the
1532
function object executing at each point in an Range.
@@ -60,7 +77,7 @@ class NdItem:
6077
"""
6178

6279
# TODO: define group type
63-
def __init__(self, global_item: Item, local_item: Item, group: any):
80+
def __init__(self, global_item: Item, local_item: Item, group: Group):
6481
# TODO: assert offset and dimensions
6582
self._global_item = global_item
6683
self._local_item = local_item

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def _kernel(nd_item: NdItem, a):
1616
i = nd_item.get_global_id(0)
1717

1818
a[i] += 1
19-
group_barrier(MemoryScope.DEVICE)
19+
group_barrier(nd_item.get_group(), MemoryScope.DEVICE)
2020

2121
if i == 0:
2222
for idx in range(1, a.size):

0 commit comments

Comments
 (0)