@@ -42,14 +42,20 @@ def get_group_id(self, dim):
42
42
43
43
def get_group_linear_id (self ):
44
44
"""Returns a linearized version of the work-group index."""
45
- if len (self ._index ) == 1 :
46
- return self ._index [0 ]
47
- if len (self ._index ) == 2 :
48
- return self ._index [0 ] * self ._group_range [1 ] + self ._index [1 ]
45
+ if self .dimensions == 1 :
46
+ return self .get_group_id (0 )
47
+ if self .dimensions == 2 :
48
+ return self .get_group_id (0 ) * self .get_group_range (
49
+ 1
50
+ ) + self .get_group_id (1 )
49
51
return (
50
- (self ._index [0 ] * self ._group_range [1 ] * self ._group_range [2 ])
51
- + (self ._index [1 ] * self ._group_range [2 ])
52
- + (self ._index [2 ])
52
+ (
53
+ self .get_group_id (0 )
54
+ * self .get_group_range (1 )
55
+ * self .get_group_range (2 )
56
+ )
57
+ + (self .get_group_id (1 ) * self .get_group_range (2 ))
58
+ + (self .get_group_id (2 ))
53
59
)
54
60
55
61
def get_group_range (self , dim ):
@@ -61,8 +67,8 @@ def get_group_range(self, dim):
61
67
def get_group_linear_range (self ):
62
68
"""Return the total number of work-groups in the nd_range."""
63
69
num_wg = 1
64
- for ext in self ._group_range :
65
- num_wg *= ext
70
+ for i in range ( self .dimensions ) :
71
+ num_wg *= self . get_group_range ( i )
66
72
67
73
return num_wg
68
74
@@ -76,8 +82,8 @@ def get_local_range(self, dim):
76
82
def get_local_linear_range (self ):
77
83
"""Return the total number of work-items in the work-group."""
78
84
num_wi = 1
79
- for ext in self ._local_range :
80
- num_wi *= ext
85
+ for i in range ( self .dimensions ) :
86
+ num_wi *= self . get_local_range ( i )
81
87
82
88
return num_wi
83
89
@@ -128,14 +134,14 @@ def get_linear_id(self):
128
134
Returns:
129
135
int: The linear id.
130
136
"""
131
- if len ( self ._extent ) == 1 :
132
- return self ._index [ 0 ]
133
- if len ( self ._extent ) == 2 :
134
- return self ._index [ 0 ] * self ._extent [ 1 ] + self ._index [ 1 ]
137
+ if self .dimensions == 1 :
138
+ return self .get_id ( 0 )
139
+ if self .dimensions == 2 :
140
+ return self .get_id ( 0 ) * self .get_range ( 1 ) + self .get_id ( 1 )
135
141
return (
136
- (self ._index [ 0 ] * self ._extent [ 1 ] * self ._extent [ 2 ] )
137
- + (self ._index [ 1 ] * self ._extent [ 2 ] )
138
- + (self ._index [ 2 ] )
142
+ (self .get_id ( 0 ) * self .get_range ( 1 ) * self .get_range ( 2 ) )
143
+ + (self .get_id ( 1 ) * self .get_range ( 2 ) )
144
+ + (self .get_id ( 2 ) )
139
145
)
140
146
141
147
def get_id (self , idx ):
@@ -146,6 +152,14 @@ def get_id(self, idx):
146
152
"""
147
153
return self ._index [idx ]
148
154
155
+ def get_linear_range (self ):
156
+ """Return the total number of work-items in the range."""
157
+ num_wi = 1
158
+ for i in range (self .dimensions ):
159
+ num_wi *= self .get_range (i )
160
+
161
+ return num_wi
162
+
149
163
def get_range (self , idx ):
150
164
"""Get the range size for a specific dimension.
151
165
@@ -193,7 +207,24 @@ def get_global_linear_id(self):
193
207
Returns:
194
208
int: The global linear id.
195
209
"""
196
- return self ._global_item .get_linear_id ()
210
+ # Instead of calling self._global_item.get_linear_id(), the linearization
211
+ # logic is duplicated here so that the method can be JIT compiled by
212
+ # numba-dpex and works in both Python and Numba nopython modes.
213
+ if self .dimensions == 1 :
214
+ return self .get_global_id (0 )
215
+ if self .dimensions == 2 :
216
+ return self .get_global_id (0 ) * self .get_global_range (
217
+ 1
218
+ ) + self .get_global_id (1 )
219
+ return (
220
+ (
221
+ self .get_global_id (0 )
222
+ * self .get_global_range (1 )
223
+ * self .get_global_range (2 )
224
+ )
225
+ + (self .get_global_id (1 ) * self .get_global_range (2 ))
226
+ + (self .get_global_id (2 ))
227
+ )
197
228
198
229
def get_local_id (self , idx ):
199
230
"""Get the local id for a specific dimension.
@@ -210,7 +241,24 @@ def get_local_linear_id(self):
210
241
Returns:
211
242
int: The local linear id.
212
243
"""
213
- return self ._local_item .get_linear_id ()
244
+ # Instead of calling self._local_item.get_linear_id(), the linearization
245
+ # logic is duplicated here so that the method can be JIT compiled by
246
+ # numba-dpex and works in both Python and Numba nopython modes.
247
+ if self .dimensions == 1 :
248
+ return self .get_local_id (0 )
249
+ if self .dimensions == 2 :
250
+ return self .get_local_id (0 ) * self .get_local_range (
251
+ 1
252
+ ) + self .get_local_id (1 )
253
+ return (
254
+ (
255
+ self .get_local_id (0 )
256
+ * self .get_local_range (1 )
257
+ * self .get_local_range (2 )
258
+ )
259
+ + (self .get_local_id (1 ) * self .get_local_range (2 ))
260
+ + (self .get_local_id (2 ))
261
+ )
214
262
215
263
def get_global_range (self , idx ):
216
264
"""Get the global range size for a specific dimension.
@@ -228,6 +276,22 @@ def get_local_range(self, idx):
228
276
"""
229
277
return self ._local_item .get_range (idx )
230
278
279
+ def get_local_linear_range (self ):
280
+ """Return the total number of work-items in the work-group."""
281
+ num_wi = 1
282
+ for i in range (self .dimensions ):
283
+ num_wi *= self .get_local_range (i )
284
+
285
+ return num_wi
286
+
287
+ def get_global_linear_range (self ):
288
+ """Return the total number of work-items in the global range."""
289
+ num_wi = 1
290
+ for i in range (self .dimensions ):
291
+ num_wi *= self .get_global_range (i )
292
+
293
+ return num_wi
294
+
231
295
def get_group (self ):
232
296
"""Returns the group.
233
297
0 commit comments