17
17
from numba_dpex import config , spirv_generator
18
18
from numba_dpex .core .descriptor import dpex_kernel_target
19
19
from numba_dpex .core .exceptions import (
20
- InvalidKernelLaunchArgsError ,
20
+ ExecutionQueueInferenceError ,
21
21
UnsupportedKernelArgumentError ,
22
22
)
23
- from numba_dpex .core .kernel_interface .indexers import NdRange , Range
24
23
from numba_dpex .core .pipelines import kernel_compiler
25
24
from numba_dpex .core .types import DpnpNdArray
26
25
27
- _KernelLauncherLowerResult = namedtuple (
28
- "_KernelLauncherLowerResult" ,
29
- ["sig" , "fndesc" , "library" , "call_helper" ],
30
- )
31
-
32
26
_KernelModule = namedtuple ("_KernelModule" , ["kernel_name" , "kernel_bitcode" ])
33
27
34
28
_KernelCompileResult = namedtuple (
38
32
39
33
40
34
class _KernelCompiler (_FunctionCompiler ):
35
+ """A special compiler class used to compile numba_dpex.kernel decorated
36
+ functions.
37
+ """
38
+
39
+ def _check_queue_equivalence_of_args (
40
+ self , py_func_name : str , args : [types .Type , ...]
41
+ ):
42
+ """Evaluates if all DpnpNdArray arguments passed to a kernel function
43
+ have the same DpctlSyclQueue type.
44
+
45
+ Args:
46
+ py_func_name (str): Name of the kernel that is being evaluated
47
+ args ([types.Type, ...]): List of numba inferred types for each
48
+ argument passed to the kernel
49
+
50
+ Raises:
51
+ ExecutionQueueInferenceError: If the DpnpNdArray arguments were not all allocated
52
+ on the same dpctl.SyclQueue
53
+ ExecutionQueueInferenceError: If there were no DpnpNdArray
54
+ arguments passed to the kernel.
55
+ """
56
+ common_queue = None
57
+
58
+ for arg in args :
59
+ if isinstance (arg , DpnpNdArray ):
60
+ if common_queue is None :
61
+ common_queue = arg .queue
62
+ elif common_queue != arg .queue :
63
+ raise ExecutionQueueInferenceError (
64
+ kernel_name = py_func_name , usmarray_argnum_list = []
65
+ )
66
+
67
+ if common_queue is None :
68
+ raise ExecutionQueueInferenceError (
69
+ kernel_name = py_func_name , usmarray_argnum_list = None
70
+ )
71
+
41
72
def _compile_to_spirv (
42
73
self , kernel_library , kernel_fndesc , kernel_targetctx
43
74
):
@@ -156,9 +187,6 @@ def __init__(
156
187
targetoptions ["experimental" ] = True
157
188
158
189
self ._kernel_name = pyfunc .__name__
159
- self ._range = None
160
- self ._ndrange = None
161
-
162
190
self .typingctx = self .targetdescr .typing_context
163
191
self .targetctx = self .targetdescr .target_context
164
192
@@ -185,7 +213,7 @@ def __init__(
185
213
self ._cache = NullCache ()
186
214
compiler_class = self ._impl_kinds [impl_kind ]
187
215
self ._impl_kind = impl_kind
188
- self ._compiler = compiler_class (
216
+ self ._compiler : _KernelCompiler = compiler_class (
189
217
pyfunc , self .targetdescr , targetoptions , locals , pipeline_class
190
218
)
191
219
self ._cache_hits = Counter ()
@@ -250,6 +278,14 @@ def cb_llvm(dur):
250
278
# Use counter to track recursion compilation depth
251
279
with self ._compiling_counter :
252
280
args , return_type = sigutils .normalize_signature (sig )
281
+
282
+ try :
283
+ self ._compiler ._check_queue_equivalence_of_args (
284
+ self ._kernel_name , args
285
+ )
286
+ except ExecutionQueueInferenceError as eqie :
287
+ raise eqie
288
+
253
289
# Don't recompile if signature already exists
254
290
existing = self .overloads .get (tuple (args ))
255
291
if existing is not None :
@@ -283,40 +319,11 @@ def folded(args, kws):
283
319
return kcres .kernel_module
284
320
285
321
def __getitem__ (self , args ):
286
- """Square-bracket notation for configuring the global_range and
287
- local_range settings when launching a kernel on a SYCL queue.
288
-
289
- When a Python function decorated with the @kernel decorator,
290
- is invoked it creates a KernelLauncher object. Calling the
291
- KernelLauncher objects ``__getitem__`` function inturn clones the object
292
- and sets the ``global_range`` and optionally the ``local_range``
293
- attributes with the arguments passed to ``__getitem__``.
294
-
295
- Args:
296
- args (tuple): A tuple of tuples that specify the global and
297
- optionally the local range for the kernel execution. If the
298
- argument is a two-tuple of tuple, then it is assumed that both
299
- global and local range options are specified. The first entry is
300
- considered to be the global range and the second the local range.
301
-
302
- If only a single tuple value is provided, then the kernel is
303
- launched with only a global range and the local range configuration
304
- is decided by the SYCL runtime.
305
-
306
- Returns:
307
- KernelLauncher: A clone of the KernelLauncher object, but with the
308
- global_range and local_range attributes initialized.
322
+ """Square-bracket notation for configuring launch arguments is not
323
+ supported.
309
324
"""
310
325
311
- if isinstance (args , Range ):
312
- self ._range = args
313
- elif isinstance (args , NdRange ):
314
- self ._ndrange = args
315
- else :
316
- # FIXME: Improve error message
317
- raise InvalidKernelLaunchArgsError (kernel_name = self ._kernel_name )
318
-
319
- return self
326
+ raise NotImplementedError
320
327
321
328
def __call__ (self , * args , ** kw_args ):
322
329
"""Functor to launch a kernel."""
0 commit comments