|
11 | 11 |
|
12 | 12 | import numba_dpex as dpex
|
13 | 13 | import numba_dpex.experimental as dpex_exp
|
14 |
| -from numba_dpex.kernel_api import Item, NdItem |
| 14 | +from numba_dpex.kernel_api import Item, NdItem, NdRange |
| 15 | +from numba_dpex.kernel_api import call_kernel as kapi_call_kernel |
15 | 16 | from numba_dpex.tests._helper import skip_windows
|
16 | 17 |
|
17 | 18 | _SIZE = 16
|
@@ -63,6 +64,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a):
|
63 | 64 | a[i] = 1
|
64 | 65 |
|
65 | 66 |
|
| 67 | +def _get_group_id_driver(nditem: NdItem, a): |
| 68 | + i = nditem.get_global_id(0) |
| 69 | + g = nditem.get_group() |
| 70 | + a[i] = g.get_group_id(0) |
| 71 | + |
| 72 | + |
| 73 | +def _get_group_range_driver(nditem: NdItem, a): |
| 74 | + i = nditem.get_global_id(0) |
| 75 | + g = nditem.get_group() |
| 76 | + a[i] = g.get_group_range(0) |
| 77 | + |
| 78 | + |
| 79 | +def _get_group_local_range_driver(nditem: NdItem, a): |
| 80 | + i = nditem.get_global_id(0) |
| 81 | + g = nditem.get_group() |
| 82 | + a[i] = g.get_local_range(0) |
| 83 | + |
| 84 | + |
66 | 85 | def test_item_get_id():
|
67 | 86 | a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
|
68 | 87 | dpex_exp.call_kernel(set_ones_item, dpex.Range(a.size), a)
|
@@ -155,6 +174,74 @@ def test_no_item():
|
155 | 174 | )
|
156 | 175 |
|
157 | 176 |
|
| 177 | +# TODO: https://github.com/IntelPython/numba-dpex/issues/1308 |
| 178 | +@skip_windows |
| 179 | +def test_get_group_id(): |
| 180 | + global_size = 100 |
| 181 | + group_size = 20 |
| 182 | + num_groups = global_size // group_size |
| 183 | + |
| 184 | + a = dpnp.empty(global_size, dtype=dpnp.int32) |
| 185 | + ka = dpnp.empty(global_size, dtype=dpnp.int32) |
| 186 | + expected = np.empty(global_size, dtype=np.int32) |
| 187 | + ndrange = NdRange((global_size,), (group_size,)) |
| 188 | + dpex_exp.call_kernel(dpex_exp.kernel(_get_group_id_driver), ndrange, a) |
| 189 | + kapi_call_kernel(_get_group_id_driver, ndrange, ka) |
| 190 | + |
| 191 | + for gid in range(num_groups): |
| 192 | + for lid in range(group_size): |
| 193 | + expected[gid * group_size + lid] = gid |
| 194 | + |
| 195 | + assert np.array_equal(a.asnumpy(), expected) |
| 196 | + assert np.array_equal(ka.asnumpy(), expected) |
| 197 | + |
| 198 | + |
| 199 | +# TODO: https://github.com/IntelPython/numba-dpex/issues/1308 |
| 200 | +@skip_windows |
| 201 | +def test_get_group_range(): |
| 202 | + global_size = 100 |
| 203 | + group_size = 20 |
| 204 | + num_groups = global_size // group_size |
| 205 | + |
| 206 | + a = dpnp.empty(global_size, dtype=dpnp.int32) |
| 207 | + ka = dpnp.empty(global_size, dtype=dpnp.int32) |
| 208 | + expected = np.empty(global_size, dtype=np.int32) |
| 209 | + ndrange = NdRange((global_size,), (group_size,)) |
| 210 | + dpex_exp.call_kernel(dpex_exp.kernel(_get_group_range_driver), ndrange, a) |
| 211 | + kapi_call_kernel(_get_group_range_driver, ndrange, ka) |
| 212 | + |
| 213 | + for gid in range(num_groups): |
| 214 | + for lid in range(group_size): |
| 215 | + expected[gid * group_size + lid] = num_groups |
| 216 | + |
| 217 | + assert np.array_equal(a.asnumpy(), expected) |
| 218 | + assert np.array_equal(ka.asnumpy(), expected) |
| 219 | + |
| 220 | + |
| 221 | +# TODO: https://github.com/IntelPython/numba-dpex/issues/1308 |
| 222 | +@skip_windows |
| 223 | +def test_get_group_local_range(): |
| 224 | + global_size = 100 |
| 225 | + group_size = 20 |
| 226 | + num_groups = global_size // group_size |
| 227 | + |
| 228 | + a = dpnp.empty(global_size, dtype=dpnp.int32) |
| 229 | + ka = dpnp.empty(global_size, dtype=dpnp.int32) |
| 230 | + expected = np.empty(global_size, dtype=np.int32) |
| 231 | + ndrange = NdRange((global_size,), (group_size,)) |
| 232 | + dpex_exp.call_kernel( |
| 233 | + dpex_exp.kernel(_get_group_local_range_driver), ndrange, a |
| 234 | + ) |
| 235 | + kapi_call_kernel(_get_group_local_range_driver, ndrange, ka) |
| 236 | + |
| 237 | + for gid in range(num_groups): |
| 238 | + for lid in range(group_size): |
| 239 | + expected[gid * group_size + lid] = group_size |
| 240 | + |
| 241 | + assert np.array_equal(a.asnumpy(), expected) |
| 242 | + assert np.array_equal(ka.asnumpy(), expected) |
| 243 | + |
| 244 | + |
158 | 245 | I_SIZE, J_SIZE, K_SIZE = 2, 3, 4
|
159 | 246 |
|
160 | 247 |
|
|
0 commit comments