|
| 1 | +# SPDX-FileCopyrightText: 2023 Intel Corporation |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +import inspect |
| 6 | + |
| 7 | +from numba.core import sigutils |
| 8 | + |
| 9 | +from .kernel_dispatcher import KernelDispatcher |
| 10 | + |
| 11 | + |
| 12 | +def kernel(func_or_sig=None, debug=False, cache=False, **options): |
| 13 | + """A decorator to define a kernel function. |
| 14 | +
|
| 15 | + A kernel function is conceptually equivalent to a SYCL kernel function, and |
| 16 | + gets compiled into either an OpenCL or a LevelZero SPIR-V binary kernel. |
| 17 | + A kernel decorated Python function has the following restrictions: |
| 18 | +
|
| 19 | + * The function can not return any value. |
| 20 | + * All array arguments passed to a kernel should adhere to compute |
| 21 | + follows data programming model. |
| 22 | + """ |
| 23 | + # FIXME: The options need to be evaluated and checked here like it is |
| 24 | + # done in numba.core.decorators.jit |
| 25 | + |
| 26 | + def _kernel_dispatcher(pyfunc, sigs=None): |
| 27 | + return KernelDispatcher( |
| 28 | + pyfunc=pyfunc, |
| 29 | + debug_flags=debug, |
| 30 | + enable_cache=cache, |
| 31 | + specialization_sigs=sigs, |
| 32 | + targetoptions=options, |
| 33 | + ) |
| 34 | + |
| 35 | + if func_or_sig is None: |
| 36 | + return _kernel_dispatcher |
| 37 | + elif isinstance(func_or_sig, str): |
| 38 | + raise NotImplementedError( |
| 39 | + "Specifying signatures as string is not yet supported by numba-dpex" |
| 40 | + ) |
| 41 | + elif isinstance(func_or_sig, list) or sigutils.is_signature(func_or_sig): |
| 42 | + # String signatures are not supported as passing usm_ndarray type as |
| 43 | + # a string is not possible. Numba's sigutils relies on the type being |
| 44 | + # available in Numba's `types.__dict__` and dpex types are not |
| 45 | + # registered there yet. |
| 46 | + if isinstance(func_or_sig, list): |
| 47 | + for sig in func_or_sig: |
| 48 | + if isinstance(sig, str): |
| 49 | + raise NotImplementedError( |
| 50 | + "Specifying signatures as string is not yet supported " |
| 51 | + "by numba-dpex" |
| 52 | + ) |
| 53 | + # Specialized signatures can either be a single signature or a list. |
| 54 | + # In case only one signature is provided convert it to a list |
| 55 | + if not isinstance(func_or_sig, list): |
| 56 | + func_or_sig = [func_or_sig] |
| 57 | + |
| 58 | + def _specialized_kernel_dispatcher(pyfunc): |
| 59 | + return KernelDispatcher( |
| 60 | + pyfunc=pyfunc, |
| 61 | + debug_flags=debug, |
| 62 | + enable_cache=cache, |
| 63 | + specialization_sigs=func_or_sig, |
| 64 | + ) |
| 65 | + |
| 66 | + return _specialized_kernel_dispatcher |
| 67 | + else: |
| 68 | + func = func_or_sig |
| 69 | + if not inspect.isfunction(func): |
| 70 | + raise ValueError( |
| 71 | + "Argument passed to the kernel decorator is neither a " |
| 72 | + "function object, nor a signature. If you are trying to " |
| 73 | + "specialize the kernel that takes a single argument, specify " |
| 74 | + "the return type as void explicitly." |
| 75 | + ) |
| 76 | + return _kernel_dispatcher(func) |
0 commit comments