Skip to content

Commit 73fbbd7

Browse files
author
Diptorup Deb
committed
Extend the mock group class.
1 parent f3113f6 commit 73fbbd7

File tree

1 file changed

+79
-2
lines changed

1 file changed

+79
-2
lines changed

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from .ranges import Range
1111

1212

13-
# pylint: disable=too-few-public-methods
1413
class Group:
1514
"""Analogue to the ``sycl::group`` type."""
1615

@@ -25,6 +24,82 @@ def __init__(
2524
self._local_range = local_range
2625
self._group_range = group_range
2726
self._index = index
27+
self._leader = False
28+
29+
def get_group_id(self, dim):
30+
"""Returns the index of the work-group within the global nd-range for
31+
specified dimension.
32+
33+
Since the work-items in a work-group have a defined position within the
34+
global nd-range, the returned group id can be used along with the local
35+
id to uniquely identify the work-item in the global nd-range.
36+
"""
37+
if dim > len(self._index) - 1:
38+
raise ValueError(
39+
"Dimension value is out of bounds for the group index"
40+
)
41+
return self._index[dim]
42+
43+
def get_group_linear_id(self):
44+
"""Returns a linearized version of the work-group index."""
45+
if len(self._index) == 1:
46+
return self._index[0]
47+
if len(self._index) == 2:
48+
return self._index[0] * self._group_range[1] + self._index[1]
49+
return (
50+
(self._index[0] * self._group_range[1] * self._group_range[2])
51+
+ (self._index[1] * self._group_range[2])
52+
+ (self._index[2])
53+
)
54+
55+
def get_group_range(self):
56+
"""Returns a range representing the number of groups in the nd-range."""
57+
return self._group_range
58+
59+
def get_group_linear_range(self):
60+
"""Return the total number of work-groups in the nd_range."""
61+
num_wg = 1
62+
for ext in self._group_range:
63+
num_wg *= ext
64+
65+
return num_wg
66+
67+
def get_local_range(self):
68+
"""Returns a SYCL range representing all dimensions of the local
69+
range. This local range may have been provided by the programmer, or
70+
chosen by the SYCL runtime.
71+
"""
72+
return self._local_range
73+
74+
def get_local_linear_range(self):
75+
"""Return the total number of work-items in the work-group."""
76+
num_wi = 1
77+
for ext in self._local_range:
78+
num_wi *= ext
79+
80+
return num_wi
81+
82+
@property
83+
def leader(self):
84+
"""Return true for exactly one work-item in the work-group, if the
85+
calling work-item is the leader of the work-group, and false for all
86+
other work-items in the work-group.
87+
88+
The leader of the work-group is determined during construction of the
89+
work-group, and is invariant for the lifetime of the work-group. The
90+
leader of the work-group is guaranteed to be the work-item with a
91+
local id of 0.
92+
93+
94+
Returns:
95+
bool: If the work item is the designated leader of the
96+
"""
97+
return self._leader
98+
99+
@leader.setter
100+
def leader(self, work_item_id):
101+
"""Sets the leader attribute for the group."""
102+
self._leader = work_item_id
28103

29104

30105
class Item:
@@ -45,7 +120,7 @@ def get_linear_id(self):
45120
"""
46121
if len(self._extent) == 1:
47122
return self._index[0]
48-
if len(self._extent) == 1:
123+
if len(self._extent) == 2:
49124
return self._index[0] * self._extent[1] + self._index[1]
50125
return (
51126
(self._index[0] * self._extent[1] * self._extent[2])
@@ -90,6 +165,8 @@ def __init__(self, global_item: Item, local_item: Item, group: Group):
90165
self._global_item = global_item
91166
self._local_item = local_item
92167
self._group = group
168+
if self.get_local_linear_id() == 0:
169+
self._group.leader = True
93170

94171
def get_global_id(self, idx):
95172
"""Get the global id for a specific dimension.

0 commit comments

Comments
 (0)