Skip to content

Commit 86d5d4e

Browse files
author
Diptorup Deb
committed
Simplification of exec. queue derivation.
- The removal of support for NumPy args in kernels helps simplify the logic for execution queue derivation. - The ComputeFollowsDataInferenceError was renamed to ExecutionQueueInferenceError, and the old ExecutionQueueInferenceError was removed.
1 parent 21d4085 commit 86d5d4e

File tree

4 files changed

+25
-169
lines changed

4 files changed

+25
-169
lines changed

numba_dpex/core/exceptions.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def __init__(self, kernel_name, dim, work_groups, work_items) -> None:
182182
super().__init__(self.message)
183183

184184

185-
class ComputeFollowsDataInferenceError(Exception):
185+
class ExecutionQueueInferenceError(Exception):
186186
"""Exception raised when an execution queue for a given array expression or
187187
a kernel function could not be deduced using the compute-follows-data
188188
programming model.
@@ -194,34 +194,20 @@ class ComputeFollowsDataInferenceError(Exception):
194194
which the array operands were allocated. Computation is required to occur
195195
on the same device where the arrays currently reside.
196196
197-
A ComputeFollowsDataInferenceError is raised when the execution queue using
197+
A ExecutionQueueInferenceError is raised when the execution queue using
198198
compute-follows-data rules could not be deduced. It may happen when arrays
199199
that have a device attribute such as ``dpctl.tensor.usm_ndarray`` are mixed
200200
with host arrays such as ``numpy.ndarray``. The error may also be raised if
201201
the array operands are allocated on different devices.
202202
203203
Args:
204204
kernel_name : Name of the kernel function for which the error occurred.
205-
ndarray_argnum_list: The list of ``numpy.ndarray`` arguments identified
206-
by the argument position that caused the error.
207205
usmarray_argnum_list: The list of ``dpctl.tensor.usm_ndarray`` arguments
208206
identified by the argument position that caused the error.
209207
"""
210208

211-
def __init__(
212-
self, kernel_name, ndarray_argnum_list=None, *, usmarray_argnum_list
213-
) -> None:
214-
if ndarray_argnum_list and usmarray_argnum_list:
215-
ndarray_args = ",".join([str(i) for i in ndarray_argnum_list])
216-
usmarray_args = ",".join([str(i) for i in usmarray_argnum_list])
217-
self.message = (
218-
f'Kernel "{kernel_name}" has arguments of both usm_ndarray and '
219-
"non-usm_ndarray types. Mixing of arguments of different "
220-
"array types is disallowed. "
221-
f"Arguments {ndarray_args} are non-usm arrays, "
222-
f"and arguments {usmarray_args} are usm arrays."
223-
)
224-
elif usmarray_argnum_list is not None:
209+
def __init__(self, kernel_name, *, usmarray_argnum_list) -> None:
210+
if usmarray_argnum_list is not None:
225211
usmarray_args = ",".join([str(i) for i in usmarray_argnum_list])
226212
self.message = (
227213
f'Execution queue for kernel "{kernel_name}" could '
@@ -232,32 +218,6 @@ def __init__(
232218
super().__init__(self.message)
233219

234220

235-
class ExecutionQueueInferenceError(Exception):
236-
"""Exception raised when an execution queue could not be deduced for NumPy
237-
ndarray kernel arguments.
238-
239-
Args:
240-
kernel_name (str): Name of kernel where the error was raised.
241-
242-
.. deprecated:: 0.19
243-
"""
244-
245-
def __init__(self, kernel_name) -> None:
246-
warn(
247-
"The ExecutionQueueInferenceError class is deprecated, and will "
248-
+ "be removed once support for NumPy ndarrays as kernel arguments "
249-
+ "is removed.",
250-
DeprecationWarning,
251-
stacklevel=2,
252-
)
253-
self.message = (
254-
f'Kernel "{kernel_name}" was called with NumPy ndarray arguments '
255-
"outside a dpctl.device_context. The execution queue to be used "
256-
"could not be deduced."
257-
)
258-
super().__init__(self.message)
259-
260-
261221
class UnsupportedBackendError(Exception):
262222
"""Exception raised when the target device is not supported by dpex.
263223

numba_dpex/core/kernel_interface/utils.py

Lines changed: 15 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from warnings import warn
65

7-
import dpctl
8-
from numba.core.types import Array as NpArrayType
9-
10-
from numba_dpex.core.exceptions import (
11-
ComputeFollowsDataInferenceError,
12-
ExecutionQueueInferenceError,
13-
)
6+
from numba_dpex.core.exceptions import ExecutionQueueInferenceError
147
from numba_dpex.core.types import USMNdArray
158

169

@@ -50,131 +43,34 @@ def chk_compute_follows_data_compliance(usm_array_arglist):
5043
def determine_kernel_launch_queue(args, argtypes, kernel_name):
5144
"""Determines the queue where the kernel is to be launched.
5245
53-
The execution queue is derived using the following algorithm. In future,
54-
support for ``numpy.ndarray`` and ``dpctl.device_context`` is to be
55-
removed and queue derivation will follows Python Array API's
56-
"compute follows data" logic.
57-
58-
Check if there are array arguments.
59-
True:
60-
Check if all array arguments are of type numpy.ndarray
61-
(numba.types.Array)
62-
True:
63-
Check if the kernel was invoked from within a
64-
dpctl.device_context.
65-
True:
66-
Provide a deprecation warning for device_context use and
67-
point to using dpctl.tensor.usm_ndarray or dpnp.ndarray
68-
69-
return dpctl.get_current_queue
70-
False:
71-
Raise ExecutionQueueInferenceError
72-
False:
73-
Check if all of the arrays are USMNdarray
74-
True:
75-
Check if execution queue could be inferred using
76-
compute follows data rules
77-
True:
78-
return the compute follows data inferred queue
79-
False:
80-
Raise ComputeFollowsDataInferenceError
81-
False:
82-
Raise ComputeFollowsDataInferenceError
83-
False:
84-
Check if the kernel was invoked from within a dpctl.device_context.
85-
True:
86-
Provide a deprecation warning for device_context use and
87-
point to using dpctl.tensor.usm_ndarray of dpnp.ndarray
88-
89-
return dpctl.get_current_queue
90-
False:
91-
Raise ExecutionQueueInferenceError
46+
The execution queue is derived following Python Array API's
47+
"compute follows data" programming model.
9248
9349
Args:
94-
args : A list of arguments passed to the kernel stored in the
95-
launcher.
9650
argtypes : The Numba inferred type for each argument.
51+
kernel_name : The name of the kernel function
9752
9853
Returns:
9954
A queue the common queue used to allocate the arrays. If no such
10055
queue exists, then raises an Exception.
10156
10257
Raises:
103-
ComputeFollowsDataInferenceError: If the queue could not be inferred
104-
using compute follows data rules.
105-
ExecutionQueueInferenceError: If the queue could not be inferred
106-
using the dpctl queue manager.
58+
ExecutionQueueInferenceError: If the queue could not be inferred.
10759
"""
10860

109-
# FIXME: The args parameter is not needed once numpy support is removed
110-
111-
# Needed as USMNdArray derives from Array
112-
array_argnums = [
113-
i
114-
for i, _ in enumerate(args)
115-
if isinstance(argtypes[i], NpArrayType)
116-
and not isinstance(argtypes[i], USMNdArray)
117-
]
11861
usmarray_argnums = [
11962
i for i, _ in enumerate(args) if isinstance(argtypes[i], USMNdArray)
12063
]
12164

122-
# if usm and non-usm array arguments are getting mixed, then the
123-
# execution queue cannot be inferred using compute follows data rules.
124-
if array_argnums and usmarray_argnums:
125-
raise ComputeFollowsDataInferenceError(
126-
array_argnums, usmarray_argnum_list=usmarray_argnums
65+
usm_array_args = [
66+
argtype for i, argtype in enumerate(argtypes) if i in usmarray_argnums
67+
]
68+
69+
queue = chk_compute_follows_data_compliance(usm_array_args)
70+
71+
if not queue:
72+
raise ExecutionQueueInferenceError(
73+
kernel_name, usmarray_argnum_list=usmarray_argnums
12774
)
128-
elif array_argnums and not usmarray_argnums:
129-
if dpctl.is_in_device_context():
130-
warn(
131-
"Support for dpctl.device_context to specify the "
132-
+ "execution queue is deprecated. "
133-
+ "Use dpctl.tensor.usm_ndarray based array "
134-
+ "containers instead. ",
135-
DeprecationWarning,
136-
stacklevel=2,
137-
)
138-
warn(
139-
"Support for NumPy ndarray objects as kernel arguments is "
140-
+ "deprecated. Use dpctl.tensor.usm_ndarray based array "
141-
+ "containers instead. ",
142-
DeprecationWarning,
143-
stacklevel=2,
144-
)
145-
return dpctl.get_current_queue()
146-
else:
147-
raise ExecutionQueueInferenceError(kernel_name)
148-
elif usmarray_argnums and not array_argnums:
149-
if dpctl.is_in_device_context():
150-
warn(
151-
"dpctl.device_context ignored as the kernel arguments "
152-
+ "are dpctl.tensor.usm_ndarray based array containers."
153-
)
154-
usm_array_args = [
155-
argtype
156-
for i, argtype in enumerate(argtypes)
157-
if i in usmarray_argnums
158-
]
159-
160-
queue = chk_compute_follows_data_compliance(usm_array_args)
16175

162-
if not queue:
163-
raise ComputeFollowsDataInferenceError(
164-
kernel_name, usmarray_argnum_list=usmarray_argnums
165-
)
166-
167-
return queue
168-
else:
169-
if dpctl.is_in_device_context():
170-
warn(
171-
"Support for dpctl.device_context to specify the "
172-
+ "execution queue is deprecated. "
173-
+ "Use dpctl.tensor.usm_ndarray based array "
174-
+ "containers instead. ",
175-
DeprecationWarning,
176-
stacklevel=2,
177-
)
178-
return dpctl.get_current_queue()
179-
else:
180-
raise ExecutionQueueInferenceError(kernel_name)
76+
return queue

numba_dpex/core/passes/parfor_legalize_cfd_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
get_parfor_params,
1313
)
1414

15-
from numba_dpex.core.exceptions import ComputeFollowsDataInferenceError
15+
from numba_dpex.core.exceptions import ExecutionQueueInferenceError
1616
from numba_dpex.core.parfors.parfor_lowerer import ParforLowerFactory
1717
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray
1818

@@ -27,7 +27,7 @@ class ParforLegalizeCFDPassImpl:
2727
__array_ufunc__ method of DpnpNdArray class. The pass fixes the LHS type by
2828
properly applying compute follows data programming model. The pass first
2929
checks if the right-hand-side (RHS) DpnpNdArray arguments are on the same
30-
device, else raising a ComputeFollowsDataInferenceError. Once the RHS has
30+
device, else raising a ExecutionQueueInferenceError. Once the RHS has
3131
been validated, the LHS type is updated.
3232
3333
The pass also updated the usm_type of the LHS based on a USM type
@@ -92,7 +92,7 @@ def _check_cfd_parfor_params(self, parfor, checklist):
9292
)
9393
# Check compute follows data on the dpnp arrays in checklist
9494
if len(deviceTypes) > 1:
95-
raise ComputeFollowsDataInferenceError(
95+
raise ExecutionQueueInferenceError(
9696
kernel_name=parfor.loc.short(),
9797
usmarray_argnum_list=[],
9898
)

numba_dpex/tests/core/passes/test_parfor_legalize_cfd_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pytest
1414

1515
from numba_dpex import dpjit
16-
from numba_dpex.core.exceptions import ComputeFollowsDataInferenceError
16+
from numba_dpex.core.exceptions import ExecutionQueueInferenceError
1717
from numba_dpex.tests._helper import skip_no_opencl_gpu
1818

1919
shapes = [10, (2, 5)]
@@ -62,7 +62,7 @@ def test_parfor_legalize_cfd_pass_raise():
6262
a = dpnp.zeros(shape=10, device="cpu")
6363
b = dpnp.ones(shape=10, device="gpu")
6464

65-
with pytest.raises(ComputeFollowsDataInferenceError):
65+
with pytest.raises(ExecutionQueueInferenceError):
6666
func1(a, b)
6767

6868

@@ -78,5 +78,5 @@ def vecadd_prange(a, b):
7878
c[idx] = a[idx] + b[idx]
7979
return c
8080

81-
with pytest.raises(ComputeFollowsDataInferenceError):
81+
with pytest.raises(ExecutionQueueInferenceError):
8282
vecadd_prange(a, b)

0 commit comments

Comments
 (0)