Skip to content

Commit a62359a

Browse files
author
Diptorup Deb
committed
Adds compute follows data check to KernelDispatcher
- Added checks for compute follows data compliance to kernel compilation. - Removed support for __get_item__ in KernelDispatcher - Address review comments.
1 parent 79e503b commit a62359a

File tree

3 files changed

+146
-43
lines changed

3 files changed

+146
-43
lines changed

numba_dpex/core/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,12 @@ def __init__(self, kernel_name, *, usmarray_argnum_list) -> None:
215215
f"usm_ndarray arguments {usmarray_args} were not allocated "
216216
"on the same queue."
217217
)
218+
else:
219+
self.message = (
220+
f'Execution queue for kernel "{kernel_name}" could '
221+
"be deduced using compute follows data programming model. The "
222+
"kernel has no USMNdArray argument."
223+
)
218224
super().__init__(self.message)
219225

220226

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,12 @@
1717
from numba_dpex import config, spirv_generator
1818
from numba_dpex.core.descriptor import dpex_kernel_target
1919
from numba_dpex.core.exceptions import (
20-
InvalidKernelLaunchArgsError,
20+
ExecutionQueueInferenceError,
2121
UnsupportedKernelArgumentError,
2222
)
23-
from numba_dpex.core.kernel_interface.indexers import NdRange, Range
2423
from numba_dpex.core.pipelines import kernel_compiler
2524
from numba_dpex.core.types import DpnpNdArray
2625

27-
_KernelLauncherLowerResult = namedtuple(
28-
"_KernelLauncherLowerResult",
29-
["sig", "fndesc", "library", "call_helper"],
30-
)
31-
3226
_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])
3327

3428
_KernelCompileResult = namedtuple(
@@ -38,6 +32,43 @@
3832

3933

4034
class _KernelCompiler(_FunctionCompiler):
35+
"""A special compiler class used to compile numba_dpex.kernel decorated
36+
functions.
37+
"""
38+
39+
def _check_queue_equivalence_of_args(
40+
self, py_func_name: str, args: [types.Type, ...]
41+
):
42+
"""Evaluates if all DpnpNdArray arguments passed to a kernel function
43+
has the same DpctlSyclQueue type.
44+
45+
Args:
46+
py_func_name (str): Name of the kernel that is being evaluated
47+
args (types.Type, ...]): List of numba inferred types for each
48+
argument passed to the kernel
49+
50+
Raises:
51+
ExecutionQueueInferenceError: If all DpnpNdArray were not allocated
52+
on the same dpctl.SyclQueue
53+
ExecutionQueueInferenceError: If there were not DpnpNdArray
54+
arguments passed to the kernel.
55+
"""
56+
common_queue = None
57+
58+
for arg in args:
59+
if isinstance(arg, DpnpNdArray):
60+
if common_queue is None:
61+
common_queue = arg.queue
62+
elif common_queue != arg.queue:
63+
raise ExecutionQueueInferenceError(
64+
kernel_name=py_func_name, usmarray_argnum_list=[]
65+
)
66+
67+
if common_queue is None:
68+
raise ExecutionQueueInferenceError(
69+
kernel_name=py_func_name, usmarray_argnum_list=None
70+
)
71+
4172
def _compile_to_spirv(
4273
self, kernel_library, kernel_fndesc, kernel_targetctx
4374
):
@@ -156,9 +187,6 @@ def __init__(
156187
targetoptions["experimental"] = True
157188

158189
self._kernel_name = pyfunc.__name__
159-
self._range = None
160-
self._ndrange = None
161-
162190
self.typingctx = self.targetdescr.typing_context
163191
self.targetctx = self.targetdescr.target_context
164192

@@ -185,7 +213,7 @@ def __init__(
185213
self._cache = NullCache()
186214
compiler_class = self._impl_kinds[impl_kind]
187215
self._impl_kind = impl_kind
188-
self._compiler = compiler_class(
216+
self._compiler: _KernelCompiler = compiler_class(
189217
pyfunc, self.targetdescr, targetoptions, locals, pipeline_class
190218
)
191219
self._cache_hits = Counter()
@@ -250,6 +278,14 @@ def cb_llvm(dur):
250278
# Use counter to track recursion compilation depth
251279
with self._compiling_counter:
252280
args, return_type = sigutils.normalize_signature(sig)
281+
282+
try:
283+
self._compiler._check_queue_equivalence_of_args(
284+
self._kernel_name, args
285+
)
286+
except ExecutionQueueInferenceError as eqie:
287+
raise eqie
288+
253289
# Don't recompile if signature already exists
254290
existing = self.overloads.get(tuple(args))
255291
if existing is not None:
@@ -283,40 +319,11 @@ def folded(args, kws):
283319
return kcres.kernel_module
284320

285321
def __getitem__(self, args):
286-
"""Square-bracket notation for configuring the global_range and
287-
local_range settings when launching a kernel on a SYCL queue.
288-
289-
When a Python function decorated with the @kernel decorator,
290-
is invoked it creates a KernelLauncher object. Calling the
291-
KernelLauncher objects ``__getitem__`` function inturn clones the object
292-
and sets the ``global_range`` and optionally the ``local_range``
293-
attributes with the arguments passed to ``__getitem__``.
294-
295-
Args:
296-
args (tuple): A tuple of tuples that specify the global and
297-
optionally the local range for the kernel execution. If the
298-
argument is a two-tuple of tuple, then it is assumed that both
299-
global and local range options are specified. The first entry is
300-
considered to be the global range and the second the local range.
301-
302-
If only a single tuple value is provided, then the kernel is
303-
launched with only a global range and the local range configuration
304-
is decided by the SYCL runtime.
305-
306-
Returns:
307-
KernelLauncher: A clone of the KernelLauncher object, but with the
308-
global_range and local_range attributes initialized.
322+
"""Square-bracket notation for configuring launch arguments is not
323+
supported.
309324
"""
310325

311-
if isinstance(args, Range):
312-
self._range = args
313-
elif isinstance(args, NdRange):
314-
self._ndrange = args
315-
else:
316-
# FIXME: Improve error message
317-
raise InvalidKernelLaunchArgsError(kernel_name=self._kernel_name)
318-
319-
return self
326+
raise NotImplementedError
320327

321328
def __call__(self, *args, **kw_args):
322329
"""Functor to launch a kernel."""
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
import dpctl
7+
import dpnp
8+
import pytest
9+
10+
import numba_dpex.experimental as exp_dpex
11+
from numba_dpex import Range
12+
from numba_dpex.core.exceptions import ExecutionQueueInferenceError
13+
14+
15+
@exp_dpex.kernel(
16+
release_gil=False,
17+
no_compile=True,
18+
no_cpython_wrapper=True,
19+
no_cfunc_wrapper=True,
20+
)
21+
def add(a, b, c):
22+
c[0] = b[0] + a[0]
23+
24+
25+
def test_successful_execution_queue_inference():
26+
"""
27+
Tests if KernelDispatcher successfully infers the execution queue for the
28+
kernel.
29+
"""
30+
31+
q = dpctl.SyclQueue()
32+
a = dpnp.ones(100, sycl_queue=q)
33+
b = dpnp.ones_like(a, sycl_queue=q)
34+
c = dpnp.zeros_like(a, sycl_queue=q)
35+
r = Range(100)
36+
37+
# FIXME: This test fails unexpectedly if the NUMBA_CAPTURED_ERRORS is set
38+
# to "new_style"
39+
try:
40+
exp_dpex.call_kernel(add, r, a, b, c)
41+
except:
42+
pytest.fail("Unexpected error when calling kernel")
43+
44+
assert c[0] == b[0] + a[0]
45+
46+
47+
def test_execution_queue_inference_error():
48+
"""
49+
Tests if KernelDispatcher successfully raised ExecutionQueueInferenceError
50+
when dpnp.ndarray arguments do not share the same dpctl.SyclQueue
51+
instance.
52+
"""
53+
54+
q1 = dpctl.SyclQueue()
55+
q2 = dpctl.SyclQueue()
56+
a = dpnp.ones(100, sycl_queue=q1)
57+
b = dpnp.ones_like(a, sycl_queue=q2)
58+
c = dpnp.zeros_like(a, sycl_queue=q1)
59+
r = Range(100)
60+
61+
from numba.core import config
62+
63+
current_captured_error_style = config.CAPTURED_ERRORS
64+
config.CAPTURED_ERRORS = "new_style"
65+
66+
with pytest.raises(ExecutionQueueInferenceError):
67+
exp_dpex.call_kernel(add, r, a, b, c)
68+
69+
config.CAPTURED_ERRORS = current_captured_error_style
70+
71+
72+
def test_error_when_no_array_args():
73+
"""
74+
Tests if KernelDispatcher successfully raised ExecutionQueueInferenceError
75+
when no dpnp.ndarray arguments were passed to a kernel.
76+
"""
77+
a = 1
78+
b = 2
79+
c = 3
80+
r = Range(100)
81+
82+
from numba.core import config
83+
84+
current_captured_error_style = config.CAPTURED_ERRORS
85+
config.CAPTURED_ERRORS = "new_style"
86+
87+
with pytest.raises(ExecutionQueueInferenceError):
88+
exp_dpex.call_kernel(add, r, a, b, c)
89+
90+
config.CAPTURED_ERRORS = current_captured_error_style

0 commit comments

Comments
 (0)