Skip to content

Commit 15b4a59

Browse files
author
Diptorup Deb
committed
Adds a kernel launcher function for mock kernels.
1 parent 73fbbd7 commit 15b4a59

File tree

3 files changed

+148
-0
lines changed

3 files changed

+148
-0
lines changed

numba_dpex/kernel_api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .atomic_ref import AtomicRef
1313
from .barrier import group_barrier
1414
from .index_space_ids import Group, Item, NdItem
15+
from .launcher import call_kernel
1516
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
1617
from .ranges import NdRange, Range
1718

@@ -26,4 +27,5 @@
2627
"NdItem",
2728
"Item",
2829
"group_barrier",
30+
"call_kernel",
2931
]

numba_dpex/kernel_api/launcher.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Implementation of mock kernel launcher functions
6+
"""
7+
8+
from inspect import signature
9+
from itertools import product
10+
11+
from .index_space_ids import Group, Item, NdItem
12+
from .ranges import NdRange, Range
13+
14+
15+
def _range_kernel_launcher(kernel_fn, index_range, *kernel_args):
16+
"""Executes a function that mocks a range kernel.
17+
18+
Converts the range into a set of index tuple that represent an element in
19+
the iteration domain over which the kernel will be executed. Then the
20+
kernel is called sequentially over that set of indices after each index
21+
value is used to construct an Item object.
22+
23+
Args:
24+
kernel_fn : A callable function object
25+
index_range (numba_dpex.Range): An instance of a Range object
26+
27+
Raises:
28+
ValueError: If the number of passed in kernel arguments is not the
29+
number of function parameters subtracted by one. The first kernel
30+
argument is expected to be an Item object.
31+
"""
32+
33+
range_sets = [range(ir) for ir in index_range]
34+
index_tuples = list(product(*range_sets))
35+
36+
for idx in index_tuples:
37+
it = Item(extent=index_range, index=idx)
38+
39+
if len(signature(kernel_fn).parameters) - len(kernel_args) != 1:
40+
raise ValueError(
41+
"Required number of kernel function arguments do not "
42+
"match provided number of kernel args"
43+
)
44+
45+
kernel_fn(it, *kernel_args)
46+
47+
48+
def _ndrange_kernel_launcher(kernel_fn, index_range, *kernel_args):
49+
"""Executes a function that mocks a range kernel.
50+
51+
Args:
52+
kernel_fn : A callable function object
53+
index_range (numba_dpex.NdRange): An instance of a NdRange object
54+
55+
Raises:
56+
ValueError: If the number of passed in kernel arguments is not the
57+
number of function parameters subtracted by one. The first kernel
58+
argument is expected to be an Item object.
59+
"""
60+
group_range = tuple(
61+
gr // lr
62+
for gr, lr in zip(index_range.global_range, index_range.local_range)
63+
)
64+
local_range_sets = [range(ir) for ir in index_range.local_range]
65+
group_range_sets = [range(gr) for gr in group_range]
66+
local_index_tuples = list(product(*local_range_sets))
67+
group_index_tuples = list(product(*group_range_sets))
68+
69+
# Loop over the groups (parallel loop)
70+
for gidx in group_index_tuples:
71+
# loop over work items in the group (parallel loop)
72+
for lidx in local_index_tuples:
73+
global_id = []
74+
# to calculate global indices
75+
for dim, gidx_val in enumerate(gidx):
76+
global_id.append(
77+
gidx_val * index_range.local_range[dim] + lidx[dim]
78+
)
79+
# Every NdItem has its own global Item, local Item and Group
80+
nditem = NdItem(
81+
global_item=Item(
82+
extent=index_range.global_range, index=global_id
83+
),
84+
local_item=Item(extent=index_range.local_range, index=lidx),
85+
group=Group(
86+
index_range.global_range,
87+
index_range.local_range,
88+
group_range,
89+
gidx,
90+
),
91+
)
92+
93+
if len(signature(kernel_fn).parameters) - len(kernel_args) != 1:
94+
raise ValueError(
95+
"Required number of kernel function arguments do not "
96+
"match provided number of kernel args"
97+
)
98+
99+
kernel_fn(nditem, *kernel_args)
100+
101+
102+
def call_kernel(kernel_fn, index_range, *kernel_args):
    """Mocks launching a kernel function over a Range or an NdRange.

    The type of ``index_range`` selects the launcher: a Range is handed to
    the range launcher and an NdRange to the nd-range launcher. All remaining
    positional arguments are forwarded to the selected launcher.

    Args:
        kernel_fn : A callable function object
        index_range (numba_dpex.Range | numba_dpex.NdRange): The index space
            over which the kernel is to be executed
        kernel_args: Positional arguments passed on to the kernel

    Raises:
        ValueError: If the first positional argument is not callable
        ValueError: If the second positional argument is not a Range or an
            NdRange object
    """
    if not callable(kernel_fn):
        raise ValueError(
            "Expected the first positional argument to be a function object"
        )

    # Pick the launcher first, then make a single call below.
    if isinstance(index_range, Range):
        launcher = _range_kernel_launcher
    elif isinstance(index_range, NdRange):
        launcher = _ndrange_kernel_launcher
    else:
        raise ValueError(
            "Expected second positional argument to be Range or NdRange object"
        )

    launcher(kernel_fn, index_range, *kernel_args)

numba_dpex/kernel_api/ranges.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99
from collections.abc import Iterable
1010

11+
from numba_dpex.core.exceptions import (
12+
UnmatchedNumberOfRangeDimsError,
13+
UnsupportedGroupWorkItemSizeError,
14+
)
15+
1116

1217
class Range(tuple):
1318
"""A data structure to encapsulate a single kernel launch parameter.
@@ -168,6 +173,22 @@ def __init__(self, global_size, local_size):
168173
+ "must be of either type Range or Iterable of int's."
169174
)
170175

176+
if len(self._local_range) != len(self._global_range):
177+
raise UnmatchedNumberOfRangeDimsError(
178+
kernel_name="",
179+
global_ndims=len(self._global_range),
180+
local_ndims=len(self._local_range),
181+
)
182+
183+
for i, _ in enumerate(self._global_range):
184+
if self._global_range[i] % self._local_range[i] != 0:
185+
raise UnsupportedGroupWorkItemSizeError(
186+
kernel_name="",
187+
dim=i,
188+
work_groups=self._global_range[i],
189+
work_items=self._local_range[i],
190+
)
191+
171192
@property
172193
def global_range(self):
173194
"""Accessor for global_range.

0 commit comments

Comments
 (0)