Skip to content

Commit b06c7b3

Browse files
author
Diptorup Deb
authored
Merge pull request #1326 from IntelPython/feature/group_index_space_id_class
Extends mock kernel API by more group index functions and adds launchers
2 parents f3113f6 + ad91f05 commit b06c7b3

File tree

7 files changed

+334
-2
lines changed

7 files changed

+334
-2
lines changed

numba_dpex/kernel_api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .atomic_ref import AtomicRef
1313
from .barrier import group_barrier
1414
from .index_space_ids import Group, Item, NdItem
15+
from .launcher import call_kernel
1516
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
1617
from .ranges import NdRange, Range
1718

@@ -26,4 +27,5 @@
2627
"NdItem",
2728
"Item",
2829
"group_barrier",
30+
"call_kernel",
2931
]

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from .ranges import Range
1111

1212

13-
# pylint: disable=too-few-public-methods
1413
class Group:
1514
"""Analogue to the ``sycl::group`` type."""
1615

@@ -25,6 +24,82 @@ def __init__(
2524
self._local_range = local_range
2625
self._group_range = group_range
2726
self._index = index
27+
self._leader = False
28+
29+
def get_group_id(self, dim):
    """Returns the index of the work-group within the global nd-range for
    the specified dimension.

    Since the work-items in a work-group have a defined position within the
    global nd-range, the returned group id can be used along with the local
    id to uniquely identify the work-item in the global nd-range.

    Args:
        dim (int): Zero-based dimension whose group index is requested.

    Returns:
        The component of the group index stored for dimension ``dim``.

    Raises:
        ValueError: If ``dim`` is negative or is not smaller than the
            number of dimensions of the group index.
    """
    # Reject negative values explicitly: plain sequence indexing would wrap
    # around and silently return the index of a different dimension.
    if not 0 <= dim < len(self._index):
        raise ValueError(
            "Dimension value is out of bounds for the group index"
        )
    return self._index[dim]
42+
43+
def get_group_linear_id(self):
    """Flatten the work-group index into a single row-major offset.

    Returns:
        The only index component for a one-dimensional group index;
        otherwise the index components combined with the extents of the
        group range so that the last dimension varies fastest.
    """
    index = self._index
    ndims = len(index)
    if ndims == 1:
        return index[0]

    extents = self._group_range
    if ndims == 2:
        return index[0] * extents[1] + index[1]

    # Anything that is not 1D or 2D is handled as a three-dimensional index.
    return (
        (index[0] * extents[1] * extents[2])
        + (index[1] * extents[2])
        + (index[2])
    )
54+
55+
def get_group_range(self):
    """Report the number of work-groups in the nd-range, per dimension.

    Returns:
        The group range object that was stored on this group.
    """
    group_range = self._group_range
    return group_range
58+
59+
def get_group_linear_range(self):
    """Count every work-group of the nd-range across all dimensions.

    Returns:
        The product of the extents of the group range.
    """
    total_groups = 1
    for extent in self._group_range:
        total_groups *= extent
    return total_groups
66+
67+
def get_local_range(self):
    """Report the size of the work-group in every dimension.

    Returns:
        The local range object that was stored on this group.
    """
    local_range = self._local_range
    return local_range
73+
74+
def get_local_linear_range(self):
    """Count every work-item of the work-group across all dimensions.

    Returns:
        The product of the extents of the local range.
    """
    total_items = 1
    for extent in self._local_range:
        total_items *= extent
    return total_items
81+
82+
@property
def leader(self):
    """Tell whether the calling work-item is the leader of the work-group.

    Exactly one work-item of a work-group is its leader; for every other
    work-item of the work-group the property is false. The leader is
    determined when the work-group is constructed, does not change during
    the lifetime of the work-group, and is guaranteed to be the work-item
    with a local id of 0.

    Returns:
        bool: True if the work-item is the designated leader of the
        work-group, False otherwise.
    """
    return self._leader

@leader.setter
def leader(self, work_item_id):
    """Store the leader flag of the group.

    Args:
        work_item_id: Despite its name, the value that is stored as the
            leader flag and later returned by the ``leader`` property.
    """
    self._leader = work_item_id
28103

29104

30105
class Item:
@@ -45,7 +120,7 @@ def get_linear_id(self):
45120
"""
46121
if len(self._extent) == 1:
47122
return self._index[0]
48-
if len(self._extent) == 1:
123+
if len(self._extent) == 2:
49124
return self._index[0] * self._extent[1] + self._index[1]
50125
return (
51126
(self._index[0] * self._extent[1] * self._extent[2])
@@ -90,6 +165,8 @@ def __init__(self, global_item: Item, local_item: Item, group: Group):
90165
self._global_item = global_item
91166
self._local_item = local_item
92167
self._group = group
168+
if self.get_local_linear_id() == 0:
169+
self._group.leader = True
93170

94171
def get_global_id(self, idx):
95172
"""Get the global id for a specific dimension.

numba_dpex/kernel_api/launcher.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Implementation of mock kernel launcher functions
6+
"""
7+
8+
from inspect import signature
9+
from itertools import product
10+
11+
from .index_space_ids import Group, Item, NdItem
12+
from .ranges import NdRange, Range
13+
14+
15+
def _range_kernel_launcher(kernel_fn, index_range, *kernel_args):
    """Executes a function that mocks a range kernel.

    Converts the range into the set of index tuples that make up the
    iteration domain over which the kernel is executed. The kernel is then
    called sequentially, once per index tuple, with an Item object built
    from that index as its first argument.

    Args:
        kernel_fn : A callable function object
        index_range (numba_dpex.Range): An instance of a Range object
        *kernel_args: The arguments forwarded to ``kernel_fn`` after the
            Item object.

    Raises:
        ValueError: If the number of passed in kernel arguments is not the
            number of function parameters subtracted by one. The first
            kernel argument is expected to be an Item object.
    """
    # The arity check does not depend on the work-item, so it is done once
    # up front. A mismatch is then reported before any Item is built and
    # also when the index range is empty.
    if len(signature(kernel_fn).parameters) - len(kernel_args) != 1:
        raise ValueError(
            "Required number of kernel function arguments do not "
            "match provided number of kernel args"
        )

    extents = [range(extent) for extent in index_range]

    # product() is consumed lazily; the index tuples are visited only once.
    for idx in product(*extents):
        kernel_fn(Item(extent=index_range, index=idx), *kernel_args)
46+
47+
48+
def _ndrange_kernel_launcher(kernel_fn, index_range, *kernel_args):
    """Executes a function that mocks an nd-range kernel.

    The number of work-groups per dimension is the global range divided by
    the local range. The kernel is then called sequentially for every
    work-item of every work-group, each time with a freshly built NdItem
    object as its first argument.

    Args:
        kernel_fn : A callable function object
        index_range (numba_dpex.NdRange): An instance of a NdRange object
        *kernel_args: The arguments forwarded to ``kernel_fn`` after the
            NdItem object.

    Raises:
        ValueError: If the number of passed in kernel arguments is not the
            number of function parameters subtracted by one. The first
            kernel argument is expected to be an NdItem object.
    """
    # The arity check does not depend on the work-item, so it is done once
    # up front instead of once per work-item in the innermost loop.
    if len(signature(kernel_fn).parameters) - len(kernel_args) != 1:
        raise ValueError(
            "Required number of kernel function arguments do not "
            "match provided number of kernel args"
        )

    global_range = index_range.global_range
    local_range = index_range.local_range
    group_range = tuple(gr // lr for gr, lr in zip(global_range, local_range))

    # Materialized because the local indices are re-iterated for every group.
    local_index_tuples = list(product(*(range(lr) for lr in local_range)))

    # Loop over the groups (parallel loop)
    for gidx in product(*(range(gr) for gr in group_range)):
        # Loop over the work-items in the group (parallel loop)
        for lidx in local_index_tuples:
            # global index = group index * work-group size + local index
            global_id = [
                group_idx * local_ext + local_idx
                for group_idx, local_ext, local_idx in zip(
                    gidx, local_range, lidx
                )
            ]
            # Every NdItem has its own global Item, local Item and Group
            nditem = NdItem(
                global_item=Item(extent=global_range, index=global_id),
                local_item=Item(extent=local_range, index=lidx),
                group=Group(global_range, local_range, group_range, gidx),
            )
            kernel_fn(nditem, *kernel_args)
100+
101+
102+
def call_kernel(kernel_fn, index_range, *kernel_args):
    """Mock-launch a kernel function over either a Range or an NdRange.

    The launcher is picked from the type of ``index_range`` and is called
    with the kernel function, the index range and the kernel arguments.

    Args:
        kernel_fn : A callable function object
        index_range: An instance of a Range or of an NdRange object
        *kernel_args: The arguments forwarded to the kernel function.

    Raises:
        ValueError: If the first positional argument is not callable
        ValueError: If the second positional argument is not a Range or an
            NdRange object
    """
    if not callable(kernel_fn):
        raise ValueError(
            "Expected the first positional argument to be a function object"
        )

    # Select the launcher first, then dispatch through a single call site.
    if isinstance(index_range, Range):
        launcher = _range_kernel_launcher
    elif isinstance(index_range, NdRange):
        launcher = _ndrange_kernel_launcher
    else:
        raise ValueError(
            "Expected second positional argument to be Range or NdRange object"
        )

    launcher(kernel_fn, index_range, *kernel_args)

numba_dpex/kernel_api/ranges.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99
from collections.abc import Iterable
1010

11+
from numba_dpex.core.exceptions import (
12+
UnmatchedNumberOfRangeDimsError,
13+
UnsupportedGroupWorkItemSizeError,
14+
)
15+
1116

1217
class Range(tuple):
1318
"""A data structure to encapsulate a single kernel launch parameter.
@@ -168,6 +173,22 @@ def __init__(self, global_size, local_size):
168173
+ "must be of either type Range or Iterable of int's."
169174
)
170175

176+
if len(self._local_range) != len(self._global_range):
177+
raise UnmatchedNumberOfRangeDimsError(
178+
kernel_name="",
179+
global_ndims=len(self._global_range),
180+
local_ndims=len(self._local_range),
181+
)
182+
183+
for i, _ in enumerate(self._global_range):
184+
if self._global_range[i] % self._local_range[i] != 0:
185+
raise UnsupportedGroupWorkItemSizeError(
186+
kernel_name="",
187+
dim=i,
188+
work_groups=self._global_range[i],
189+
work_items=self._local_range[i],
190+
)
191+
171192
@property
172193
def global_range(self):
173194
"""Accessor for global_range.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import numpy
6+
7+
from numba_dpex import kernel_api as kapi
8+
9+
10+
def test_range_kernel_call1D():
    """Launch a 1D vector addition over an NdRange and check the result."""

    def vecadd(item: kapi.NdItem, a, b, c):
        gid = item.get_global_id(0)
        c[gid] = a[gid] + b[gid]

    size = 100
    lhs = numpy.ones(size)
    rhs = numpy.ones(size)
    out = numpy.empty(size)

    kapi.call_kernel(vecadd, kapi.NdRange((size,), (20,)), lhs, rhs, out)

    assert numpy.allclose(out, lhs + rhs)
22+
23+
24+
def test_range_kernel_call2D():
    """Launch a 2D element-wise addition over an NdRange and check it."""

    def vecadd(item: kapi.NdItem, a, b, c):
        row = item.get_global_id(0)
        col = item.get_global_id(1)
        c[row, col] = a[row, col] + b[row, col]

    shape = (10, 10)
    lhs = numpy.ones(shape)
    rhs = numpy.ones(shape)
    out = numpy.empty(shape)

    kapi.call_kernel(vecadd, kapi.NdRange(shape, (2, 2)), lhs, rhs, out)

    assert numpy.allclose(out, lhs + rhs)
37+
38+
39+
def test_range_kernel_call3D():
    """Launch a 3D element-wise addition over an NdRange and check it."""

    # The kernel is launched over an NdRange and queries global ids, so its
    # first parameter is annotated as an NdItem, like the 1D and 2D tests.
    def vecadd(item: kapi.NdItem, a, b, c):
        idx = item.get_global_id(0)
        jdx = item.get_global_id(1)
        kdx = item.get_global_id(2)
        c[idx, jdx, kdx] = a[idx, jdx, kdx] + b[idx, jdx, kdx]

    a = numpy.ones((8, 8, 8))
    b = numpy.ones((8, 8, 8))
    c = numpy.empty((8, 8, 8))

    kapi.call_kernel(vecadd, kapi.NdRange((8, 8, 8), (2, 2, 2)), a, b, c)

    assert numpy.allclose(c, a + b)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import numpy
6+
7+
from numba_dpex import kernel_api as kapi
8+
9+
10+
def test_range_kernel_call1D():
    """Launch a 1D vector addition over a Range and check the result."""

    def vecadd(item: kapi.Item, a, b, c):
        pos = item.get_id(0)
        c[pos] = a[pos] + b[pos]

    size = 100
    lhs = numpy.ones(size)
    rhs = numpy.ones(size)
    out = numpy.empty(size)

    kapi.call_kernel(vecadd, kapi.Range(size), lhs, rhs, out)

    assert numpy.allclose(out, lhs + rhs)
22+
23+
24+
def test_range_kernel_call2D():
    """Launch a 2D element-wise addition over a Range and check it."""

    def vecadd(item: kapi.Item, a, b, c):
        row = item.get_id(0)
        col = item.get_id(1)
        c[row, col] = a[row, col] + b[row, col]

    shape = (10, 10)
    lhs = numpy.ones(shape)
    rhs = numpy.ones(shape)
    out = numpy.empty(shape)

    kapi.call_kernel(vecadd, kapi.Range(*shape), lhs, rhs, out)

    assert numpy.allclose(out, lhs + rhs)
37+
38+
39+
def test_range_kernel_call3D():
    """Launch a 3D element-wise addition over a Range and check it."""

    def vecadd(item: kapi.Item, a, b, c):
        plane = item.get_id(0)
        row = item.get_id(1)
        col = item.get_id(2)
        c[plane, row, col] = a[plane, row, col] + b[plane, row, col]

    shape = (5, 5, 5)
    lhs = numpy.ones(shape)
    rhs = numpy.ones(shape)
    out = numpy.empty(shape)

    kapi.call_kernel(vecadd, kapi.Range(*shape), lhs, rhs, out)

    assert numpy.allclose(out, lhs + rhs)

0 commit comments

Comments
 (0)