@@ -35,6 +35,19 @@ def set_last_one_item(item: Item, a):
35
35
a [i ] = 1
36
36
37
37
38
+ @dpex_exp .kernel
39
+ def set_last_one_linear_item (item : Item , a ):
40
+ i = item .get_linear_range () - 1
41
+ a [i ] = 1
42
+
43
+
44
+ @dpex_exp .kernel
45
+ def set_last_one_linear_nd_item (nd_item : NdItem , a ):
46
+ i = nd_item .get_global_linear_range () - 1
47
+ a [0 ] = i
48
+ a [i ] = 1
49
+
50
+
38
51
@dpex_exp .kernel
39
52
def set_last_one_nd_item (item : NdItem , a ):
40
53
if item .get_global_id (0 ) == 0 :
@@ -43,6 +56,20 @@ def set_last_one_nd_item(item: NdItem, a):
43
56
a [i ] = 1
44
57
45
58
59
+ @dpex_exp .kernel
60
+ def set_last_group_one_linear_nd_item (nd_item : NdItem , a ):
61
+ i = nd_item .get_local_linear_range () - 1
62
+ a [0 ] = i
63
+ a [i ] = 1
64
+
65
+
66
+ @dpex_exp .kernel
67
+ def set_last_group_one_group_linear_nd_item (nd_item : NdItem , a ):
68
+ i = nd_item .get_group ().get_local_linear_range () - 1
69
+ a [0 ] = i
70
+ a [i ] = 1
71
+
72
+
46
73
@dpex_exp .kernel
47
74
def set_last_group_one_nd_item (item : NdItem , a ):
48
75
if item .get_global_id (0 ) == 0 :
@@ -99,6 +126,12 @@ def _get_group_range_driver(nditem: NdItem, a):
99
126
a [i ] = g .get_group_range (0 )
100
127
101
128
129
+ def _get_group_linear_range_driver (nditem : NdItem , a ):
130
+ i = nditem .get_global_linear_id ()
131
+ g = nditem .get_group ()
132
+ a [i ] = g .get_group_linear_range ()
133
+
134
+
102
135
def _get_group_local_range_driver (nditem : NdItem , a ):
103
136
i = nditem .get_global_id (0 )
104
137
g = nditem .get_group ()
@@ -122,11 +155,34 @@ def test_item_get_range():
122
155
assert np .array_equal (a .asnumpy (), want )
123
156
124
157
125
- def test_nd_item_get_global_range ():
158
+ @pytest .mark .parametrize (
159
+ "rng" ,
160
+ [dpex .Range (_SIZE ), dpex .Range (1 , _GROUP_SIZE , int (_SIZE / _GROUP_SIZE ))],
161
+ )
162
+ def test_item_get_linear_range (rng ):
126
163
a = dpnp .zeros (_SIZE , dtype = dpnp .float32 )
127
- dpex_exp .call_kernel (
128
- set_last_one_nd_item , dpex .NdRange ((a .size ,), (_GROUP_SIZE ,)), a
129
- )
164
+ dpex_exp .call_kernel (set_last_one_linear_item , rng , a )
165
+
166
+ want = np .zeros (a .size , dtype = np .float32 )
167
+ want [- 1 ] = 1
168
+
169
+ assert np .array_equal (a .asnumpy (), want )
170
+
171
+
172
+ @pytest .mark .parametrize (
173
+ "kernel,rng" ,
174
+ [
175
+ (set_last_one_nd_item , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
176
+ (set_last_one_linear_nd_item , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
177
+ (
178
+ set_last_one_linear_nd_item ,
179
+ dpex .NdRange ((1 , 1 , _SIZE ), (1 , 1 , _GROUP_SIZE )),
180
+ ),
181
+ ],
182
+ )
183
+ def test_nd_item_get_global_range (kernel , rng ):
184
+ a = dpnp .zeros (_SIZE , dtype = dpnp .float32 )
185
+ dpex_exp .call_kernel (kernel , rng , a )
130
186
131
187
want = np .zeros (a .size , dtype = np .float32 )
132
188
want [- 1 ] = 1
@@ -135,11 +191,31 @@ def test_nd_item_get_global_range():
135
191
assert np .array_equal (a .asnumpy (), want )
136
192
137
193
138
- def test_nd_item_get_local_range ():
194
+ @pytest .mark .parametrize (
195
+ "kernel,rng" ,
196
+ [
197
+ (set_last_group_one_nd_item , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
198
+ (
199
+ set_last_group_one_linear_nd_item ,
200
+ dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,)),
201
+ ),
202
+ (
203
+ set_last_group_one_linear_nd_item ,
204
+ dpex .NdRange ((1 , 1 , _SIZE ), (1 , 1 , _GROUP_SIZE )),
205
+ ),
206
+ (
207
+ set_last_group_one_group_linear_nd_item ,
208
+ dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,)),
209
+ ),
210
+ (
211
+ set_last_group_one_group_linear_nd_item ,
212
+ dpex .NdRange ((1 , 1 , _SIZE ), (1 , 1 , _GROUP_SIZE )),
213
+ ),
214
+ ],
215
+ )
216
+ def test_nd_item_get_local_range (kernel , rng ):
139
217
a = dpnp .zeros (_SIZE , dtype = dpnp .float32 )
140
- dpex_exp .call_kernel (
141
- set_last_group_one_nd_item , dpex .NdRange ((a .size ,), (_GROUP_SIZE ,)), a
142
- )
218
+ dpex_exp .call_kernel (kernel , rng , a )
143
219
144
220
want = np .zeros (a .size , dtype = np .float32 )
145
221
want [_GROUP_SIZE - 1 ] = 1
@@ -240,21 +316,32 @@ def test_get_group_id(driver, rng):
240
316
assert np .array_equal (ka .asnumpy (), expected )
241
317
242
318
243
- def test_get_group_range ():
244
- global_size = 100
245
- group_size = 20
246
- num_groups = global_size // group_size
319
+ @pytest .mark .parametrize (
320
+ "driver,rng" ,
321
+ [
322
+ (_get_group_range_driver , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
323
+ (
324
+ _get_group_linear_range_driver ,
325
+ dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,)),
326
+ ),
327
+ (
328
+ _get_group_linear_range_driver ,
329
+ dpex .NdRange ((1 , 1 , _SIZE ), (1 , 1 , _GROUP_SIZE )),
330
+ ),
331
+ ],
332
+ )
333
+ def test_get_group_range (driver , rng ):
334
+ num_groups = _SIZE // _GROUP_SIZE
247
335
248
- a = dpnp .empty (global_size , dtype = dpnp .int32 )
249
- ka = dpnp .empty (global_size , dtype = dpnp .int32 )
250
- expected = np .empty (global_size , dtype = np .int32 )
251
- ndrange = NdRange ((global_size ,), (group_size ,))
252
- dpex_exp .call_kernel (dpex_exp .kernel (_get_group_range_driver ), ndrange , a )
253
- kapi_call_kernel (_get_group_range_driver , ndrange , ka )
336
+ a = dpnp .empty (_SIZE , dtype = dpnp .int32 )
337
+ ka = dpnp .empty (_SIZE , dtype = dpnp .int32 )
338
+ expected = np .empty (_SIZE , dtype = np .int32 )
339
+ dpex_exp .call_kernel (dpex_exp .kernel (driver ), rng , a )
340
+ kapi_call_kernel (driver , rng , ka )
254
341
255
342
for gid in range (num_groups ):
256
- for lid in range (group_size ):
257
- expected [gid * group_size + lid ] = num_groups
343
+ for lid in range (_GROUP_SIZE ):
344
+ expected [gid * _GROUP_SIZE + lid ] = num_groups
258
345
259
346
assert np .array_equal (a .asnumpy (), expected )
260
347
assert np .array_equal (ka .asnumpy (), expected )
0 commit comments