10
10
from .ranges import Range
11
11
12
12
13
- # pylint: disable=too-few-public-methods
14
13
class Group :
15
14
"""Analogue to the ``sycl::group`` type."""
16
15
@@ -25,6 +24,82 @@ def __init__(
25
24
self ._local_range = local_range
26
25
self ._group_range = group_range
27
26
self ._index = index
27
+ self ._leader = False
28
+
29
+ def get_group_id (self , dim ):
30
+ """Returns the index of the work-group within the global nd-range for
31
+ specified dimension.
32
+
33
+ Since the work-items in a work-group have a defined position within the
34
+ global nd-range, the returned group id can be used along with the local
35
+ id to uniquely identify the work-item in the global nd-range.
36
+ """
37
+ if dim > len (self ._index ) - 1 :
38
+ raise ValueError (
39
+ "Dimension value is out of bounds for the group index"
40
+ )
41
+ return self ._index [dim ]
42
+
43
+ def get_group_linear_id (self ):
44
+ """Returns a linearized version of the work-group index."""
45
+ if len (self ._index ) == 1 :
46
+ return self ._index [0 ]
47
+ if len (self ._index ) == 2 :
48
+ return self ._index [0 ] * self ._group_range [1 ] + self ._index [1 ]
49
+ return (
50
+ (self ._index [0 ] * self ._group_range [1 ] * self ._group_range [2 ])
51
+ + (self ._index [1 ] * self ._group_range [2 ])
52
+ + (self ._index [2 ])
53
+ )
54
+
55
+ def get_group_range (self ):
56
+ """Returns a range representing the number of groups in the nd-range."""
57
+ return self ._group_range
58
+
59
+ def get_group_linear_range (self ):
60
+ """Return the total number of work-groups in the nd_range."""
61
+ num_wg = 1
62
+ for ext in self ._group_range :
63
+ num_wg *= ext
64
+
65
+ return num_wg
66
+
67
+ def get_local_range (self ):
68
+ """Returns a SYCL range representing all dimensions of the local
69
+ range. This local range may have been provided by the programmer, or
70
+ chosen by the SYCL runtime.
71
+ """
72
+ return self ._local_range
73
+
74
+ def get_local_linear_range (self ):
75
+ """Return the total number of work-items in the work-group."""
76
+ num_wi = 1
77
+ for ext in self ._local_range :
78
+ num_wi *= ext
79
+
80
+ return num_wi
81
+
82
+ @property
83
+ def leader (self ):
84
+ """Return true for exactly one work-item in the work-group, if the
85
+ calling work-item is the leader of the work-group, and false for all
86
+ other work-items in the work-group.
87
+
88
+ The leader of the work-group is determined during construction of the
89
+ work-group, and is invariant for the lifetime of the work-group. The
90
+ leader of the work-group is guaranteed to be the work-item with a
91
+ local id of 0.
92
+
93
+
94
+ Returns:
95
+ bool: If the work item is the designated leader of the
96
+ """
97
+ return self ._leader
98
+
99
+ @leader .setter
100
+ def leader (self , work_item_id ):
101
+ """Sets the leader attribute for the group."""
102
+ self ._leader = work_item_id
28
103
29
104
30
105
class Item :
@@ -45,7 +120,7 @@ def get_linear_id(self):
45
120
"""
46
121
if len (self ._extent ) == 1 :
47
122
return self ._index [0 ]
48
- if len (self ._extent ) == 1 :
123
+ if len (self ._extent ) == 2 :
49
124
return self ._index [0 ] * self ._extent [1 ] + self ._index [1 ]
50
125
return (
51
126
(self ._index [0 ] * self ._extent [1 ] * self ._extent [2 ])
@@ -90,6 +165,8 @@ def __init__(self, global_item: Item, local_item: Item, group: Group):
90
165
self ._global_item = global_item
91
166
self ._local_item = local_item
92
167
self ._group = group
168
+ if self .get_local_linear_id () == 0 :
169
+ self ._group .leader = True
93
170
94
171
def get_global_id (self , idx ):
95
172
"""Get the global id for a specific dimension.
0 commit comments