Skip to content

Commit 4a1b3dc

Browse files
author
Diptorup Deb
committed
Adds a compilation mode target option and device_func decorator.
- A new target option was added for the DpexKernelTarget target to compile functions using the experimental KernelDispatcher differently based on whether they are "kernels" or "device functions". Kernels have the spir_kernel calling convention, cannot return a value, enforce execution queue equivalence, and are always compiled down to device IR (SPIR-V). Device functions have the spir_func calling convention, do not have the same restrictions on return value and input arguments, and are only compiled to LLVM bitcode. - A device_func decorator was added to the experimental module. The new decorator is roughly equivalent to numba_dpex.func but uses the new KernelDispatcher and the device function compilation mode. The `device_func` decorator is registered to compile overloads in DpexExpkernelTarget. - In the kernel compilation mode, the final LLVM module is now "finalized" before conversion to SPIR-V. During finalization, all overload calls are linked into the main (kernel) module and optionally inlined.
1 parent a62ff1a commit 4a1b3dc

File tree

4 files changed

+118
-13
lines changed

4 files changed

+118
-13
lines changed

numba_dpex/core/descriptor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .targets.dpjit_target import DPEX_TARGET_NAME, DpexTargetContext
1212
from .targets.kernel_target import (
1313
DPEX_KERNEL_TARGET_NAME,
14+
CompilationMode,
1415
DpexKernelTargetContext,
1516
DpexKernelTypingContext,
1617
)
@@ -40,13 +41,17 @@ class DpexTargetOptions(CPUTargetOptions):
4041
release_gil = _option_mapping("release_gil")
4142
no_compile = _option_mapping("no_compile")
4243
use_mlir = _option_mapping("use_mlir")
44+
_compilation_mode = _option_mapping("_compilation_mode")
4345

4446
def finalize(self, flags, options):
4547
super().finalize(flags, options)
4648
_inherit_if_not_set(flags, options, "experimental", False)
4749
_inherit_if_not_set(flags, options, "release_gil", False)
4850
_inherit_if_not_set(flags, options, "no_compile", True)
4951
_inherit_if_not_set(flags, options, "use_mlir", False)
52+
_inherit_if_not_set(
53+
flags, options, "_compilation_mode", CompilationMode.KERNEL
54+
)
5055

5156

5257
class DpexKernelTarget(TargetDescriptor):

numba_dpex/core/targets/kernel_target.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55

6+
from enum import IntEnum
67
from functools import cached_property
78

89
import dpnp
@@ -30,6 +31,28 @@
3031
LLVM_SPIRV_ARGS = 112
3132

3233

34+
class CompilationMode(IntEnum):
35+
"""Flags used to determine how a function should be compiled by the
36+
numba_dpex.experimental.dispatcher.KernelDispatcher. Note the functionality
37+
will be merged into numba_dpex.core.kernel_interface.dispatcher in the
38+
future.
39+
40+
KERNEL : Indicates that the function will be compiled into an
41+
LLVM function that has ``spir_kernel`` calling
42+
convention and is compiled down to SPIR-V.
43+
Additionally, the function cannot return any value and
44+
input arguments to the function have to adhere to
45+
"compute follows data" to ensure execution queue
46+
inference.
47+
DEVICE_FUNC: Indicates that the function will be compiled into an
48+
LLVM function that has ``spir_func`` calling convention
49+
and will be compiled only into LLVM bitcode.
50+
"""
51+
52+
KERNEL = 1
53+
DEVICE_FUNC = 2
54+
55+
3356
class DpexKernelTypingContext(typing.BaseContext):
3457
"""Custom typing context to support kernel compilation.
3558

numba_dpex/experimental/decorators.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
ready to move to numba_dpex.core.
77
"""
88
import inspect
9+
from warnings import warn
910

1011
from numba.core import sigutils
1112
from numba.core.target_extension import (
@@ -14,6 +15,8 @@
1415
target_registry,
1516
)
1617

18+
from numba_dpex.core.targets.kernel_target import CompilationMode
19+
1720
from .target import DPEX_KERNEL_EXP_TARGET_NAME
1821

1922

@@ -30,6 +33,14 @@ def kernel(func_or_sig=None, **options):
3033
"""
3134

3235
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
36+
if "_compilation_mode" in options:
37+
user_compilation_mode = options["_compilation_mode"]
38+
warn(
39+
"_compilation_mode is an internal flag that should not be set "
40+
"in the decorator. The decorator defined option "
41+
f"{user_compilation_mode} is going to be ignored."
42+
)
43+
options["_compilation_mode"] = CompilationMode.KERNEL
3344

3445
# FIXME: The options need to be evaluated and checked here like it is
3546
# done in numba.core.decorators.jit
@@ -80,4 +91,44 @@ def _specialized_kernel_dispatcher(pyfunc):
8091
return _kernel_dispatcher(func)
8192

8293

83-
jit_registry[target_registry[DPEX_KERNEL_EXP_TARGET_NAME]] = kernel
94+
def device_func(func_or_sig=None, **options):
95+
"""Generates a function with a device-only calling convention, e.g.,
96+
spir_func for SPIR-V based devices.
97+
98+
The decorator is used to compile overloads in the DpexKernelTarget and
99+
users should use the decorator to define functions that are only callable
100+
from inside another device_func or a kernel.
101+
102+
A device_func is not compiled down to device binary IR and instead left as
103+
LLVM IR. It is done so that the function can be inlined fully into the
104+
kernel module from where it is used at the LLVM level, leading to more
105+
optimization opportunities.
106+
107+
Returns:
108+
KernelDispatcher: A KernelDispatcher instance with the
109+
_compilation_mode option set to DEVICE_FUNC.
110+
"""
111+
dispatcher = resolve_dispatcher_from_str(DPEX_KERNEL_EXP_TARGET_NAME)
112+
113+
if "_compilation_mode" in options:
114+
user_compilation_mode = options["_compilation_mode"]
115+
warn(
116+
"_compilation_mode is an internal flag that should not be set "
117+
"in the decorator. The decorator defined option "
118+
f"{user_compilation_mode} is going to be ignored."
119+
)
120+
options["_compilation_mode"] = CompilationMode.DEVICE_FUNC
121+
122+
def _kernel_dispatcher(pyfunc):
123+
return dispatcher(
124+
pyfunc=pyfunc,
125+
targetoptions=options,
126+
)
127+
128+
if func_or_sig is None:
129+
return _kernel_dispatcher
130+
131+
return _kernel_dispatcher(func_or_sig)
132+
133+
134+
jit_registry[target_registry[DPEX_KERNEL_EXP_TARGET_NAME]] = device_func

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515
from numba.core.compiler_lock import global_compiler_lock
1616
from numba.core.dispatcher import Dispatcher, _FunctionCompiler
1717
from numba.core.target_extension import dispatcher_registry, target_registry
18+
from numba.core.types import void
1819
from numba.core.typing.typeof import Purpose, typeof
1920

2021
from numba_dpex import config, spirv_generator
2122
from numba_dpex.core.exceptions import (
2223
ExecutionQueueInferenceError,
24+
KernelHasReturnValueError,
2325
UnsupportedKernelArgumentError,
2426
)
2527
from numba_dpex.core.pipelines import kernel_compiler
28+
from numba_dpex.core.targets.kernel_target import CompilationMode
2629
from numba_dpex.core.types import DpnpNdArray
2730

2831
from .target import DPEX_KERNEL_EXP_TARGET_NAME, dpex_exp_kernel_target
@@ -82,9 +85,10 @@ def _compile_to_spirv(
8285
kernel_func, kernel_fndesc.argtypes
8386
)
8487

85-
# makes sure that the spir_func is completely inlined into the
86-
# spir_kernel wrapper
87-
kernel_library.optimize_final_module()
88+
# Call finalize on the LLVM module. Finalization will result in
89+
# all linking libraries getting linked together and final optimization
90+
# including inlining of functions if an inlining level is specified.
91+
kernel_library.finalize()
8892
# Compile the LLVM IR to SPIR-V
8993
kernel_spirv_module = spirv_generator.llvm_to_spirv(
9094
kernel_targetctx,
@@ -144,9 +148,15 @@ def _compile_cached(
144148
try:
145149
cres: CompileResult = self._compile_core(args, return_type)
146150

147-
kernel_device_ir_module = self._compile_to_spirv(
148-
cres.library, cres.fndesc, cres.target_context
149-
)
151+
if (
152+
self.targetoptions["_compilation_mode"]
153+
== CompilationMode.KERNEL
154+
):
155+
kernel_device_ir_module: _KernelModule = self._compile_to_spirv(
156+
cres.library, cres.fndesc, cres.target_context
157+
)
158+
else:
159+
kernel_device_ir_module = None
150160

151161
kcres_attrs = []
152162

@@ -282,12 +292,28 @@ def cb_llvm(dur):
282292
with self._compiling_counter:
283293
args, return_type = sigutils.normalize_signature(sig)
284294

285-
try:
286-
self._compiler.check_queue_equivalence_of_args(
287-
self._kernel_name, args
288-
)
289-
except ExecutionQueueInferenceError as eqie:
290-
raise eqie
295+
if (
296+
self.targetoptions["_compilation_mode"]
297+
== CompilationMode.KERNEL
298+
):
299+
# Compute follows data based queue equivalence is only
300+
# evaluated for kernel functions whose arguments are
301+
# supposed to be arrays. For device_func decorated
302+
# functions, the arguments can be scalar and we skip queue
303+
# equivalence check.
304+
try:
305+
self._compiler.check_queue_equivalence_of_args(
306+
self._kernel_name, args
307+
)
308+
except ExecutionQueueInferenceError as eqie:
309+
raise eqie
310+
311+
# A function being compiled in the KERNEL compilation mode
312+
# cannot have a non-void return value
313+
if return_type and return_type != void:
314+
raise KernelHasReturnValueError(
315+
kernel_name=None, return_type=return_type, sig=sig
316+
)
291317

292318
# Don't recompile if signature already exists
293319
existing = self.overloads.get(tuple(args))

0 commit comments

Comments
 (0)