Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions docs/source/reference/kernel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,87 @@ This is similar to launch configuration in CUDA C/C++:
.. note:: The order of ``stream`` and ``sharedmem`` are reversed in Numba
compared to in CUDA C/C++.

Launch configuration access (advanced)
--------------------------------------

The configured-launch object returned by ``dispatcher[griddim, blockdim, ...]``
exposes launch metadata and callback hooks that can be consumed by advanced
tooling (for example, rewrite passes and extension integrations).

Compile-time launch-config access
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The active launch configuration is available only while compilation is in
progress for a kernel launch.

.. note:: This is compile-time state. If a launch reuses an existing compiled
kernel for the same cache key, no compilation occurs and no compile-time
launch config is set. For launch-config-sensitive kernels, a different
launch configuration can trigger a separate compilation/specialization; see
:ref:`cuda-launch-config-sensitive-compilation`.

.. code-block:: python

from numba import cuda
from numba.cuda import launchconfig

@cuda.jit
def f(x):
x[0] = 1

arr = cuda.device_array(1, dtype="i4")
with launchconfig.capture_compile_config(f) as capture:
f[1, 1](arr) # first launch triggers compilation

cfg = capture["config"]
print(cfg.griddim, cfg.blockdim, cfg.sharedmem)

For use inside compilation passes:

.. code-block:: python

from numba.cuda import launchconfig

cfg = launchconfig.ensure_current_launch_config()
print(cfg.griddim, cfg.blockdim, cfg.sharedmem, cfg.args)

Pre-launch callbacks
~~~~~~~~~~~~~~~~~~~~

Configured launches expose ``pre_launch_callbacks``. Each callback is called
immediately before launch with ``(kernel, launch_config)``.

.. code-block:: python

cfg = f[1, 1]

    def log_launch(kernel, launch_config):
        print(launch_config.griddim, launch_config.blockdim)

cfg.pre_launch_callbacks.append(log_launch)
cfg(arr)

.. _cuda-launch-config-sensitive-compilation:

Launch-config-sensitive compilation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If code generation depends on launch configuration (for example, a rewrite
pass that inspects ``cfg.blockdim`` and emits different IR), mark the active
launch as launch-config sensitive:

.. code-block:: python

cfg = launchconfig.ensure_current_launch_config()
cfg.mark_kernel_as_launch_config_sensitive()

This instructs the dispatcher/cache machinery to avoid unsafe reuse across
different launch configurations for that kernel path.

.. note:: Launch-config-sensitive cache keying for ``cache=True`` applies to
kernels that are otherwise disk-cacheable. Kernels that require external
linking files are not currently disk-cacheable.

Dispatcher objects also provide several utility methods for inspection and
creating a specialized instance:

Expand Down
123 changes: 122 additions & 1 deletion numba_cuda/numba/cuda/cext/_dispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,81 @@
#include "traceback.h"
#include "typeconv.hpp"

static Py_tss_t launch_config_tss_key = Py_tss_NEEDS_INIT;
static const char *launch_config_kw = "__numba_cuda_launch_config__";

/*
 * Create the thread-specific storage key used to hold the active launch
 * configuration. Returns 0 on success; on failure sets a Python
 * RuntimeError and returns -1.
 */
static int
launch_config_tss_init(void)
{
    if (PyThread_tss_create(&launch_config_tss_key) == 0) {
        return 0;
    }
    PyErr_SetString(PyExc_RuntimeError,
                    "Failed to initialize launch config TLS");
    return -1;
}

/*
 * Return a *borrowed* reference to the launch config stored in this
 * thread's TSS slot, or NULL when none is set.
 */
static PyObject *
launch_config_get_borrowed(void)
{
    void *stored = PyThread_tss_get(&launch_config_tss_key);
    return (PyObject *) stored;
}

/*
 * Store `obj` (may be NULL to clear) in this thread's TSS slot. The slot
 * owns a strong reference to whatever it holds: a new reference is taken
 * on `obj` before storing, and the reference held on the previous value
 * is released afterwards. On failure the new reference is rolled back, a
 * Python RuntimeError is set, and -1 is returned; otherwise returns 0.
 */
static int
launch_config_set(PyObject *obj)
{
    PyObject *previous = (PyObject *) PyThread_tss_get(&launch_config_tss_key);

    Py_XINCREF(obj);  /* reference taken on behalf of the slot; no-op if NULL */
    if (PyThread_tss_set(&launch_config_tss_key, (void *) obj) != 0) {
        Py_XDECREF(obj);  /* roll back the reference we just took */
        PyErr_SetString(PyExc_RuntimeError,
                        "Failed to set launch config TLS");
        return -1;
    }
    Py_XDECREF(previous);  /* drop the slot's reference to the old value */
    return 0;
}

/*
 * RAII helper that installs `value` as the thread-local launch config for
 * the guard's lifetime and restores the previous value on destruction.
 *
 * A NULL `value` makes the guard inert (nothing installed, nothing
 * restored). If installation fails, a Python error is left set by
 * launch_config_set(); callers must check failed() before continuing.
 */
class LaunchConfigGuard {
public:
    explicit LaunchConfigGuard(PyObject *value)
        : prev(NULL), active(false), requested(value != NULL)
    {
        if (requested) {
            /* Hold our own reference to the previous config so the
             * destructor can restore it safely. */
            prev = launch_config_get_borrowed();
            Py_XINCREF(prev);
            if (launch_config_set(value) == 0) {
                active = true;
            }
            else {
                Py_XDECREF(prev);
                prev = NULL;
            }
        }
    }

    /* True when installation was attempted but did not succeed. */
    bool failed(void) const
    {
        return requested && !active;
    }

    ~LaunchConfigGuard(void)
    {
        if (active) {
            /* Restore the saved config, then release our reference to it
             * (the TSS slot takes its own reference inside set). */
            launch_config_set(prev);
            Py_XDECREF(prev);
        }
    }

private:
    PyObject *prev;   /* owned reference to the config we replaced */
    bool active;      /* installation succeeded; destructor must restore */
    bool requested;   /* a non-NULL value was supplied */
};

/*
* Notes on the C_TRACE macro:
*
Expand Down Expand Up @@ -840,6 +915,7 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
PyObject *cfunc;
PyThreadState *ts = PyThreadState_Get();
PyObject *locals = NULL;
PyObject *launch_config = NULL;

/* If compilation is enabled, ensure that an exact match is found and if
* not compile one */
Expand All @@ -855,9 +931,26 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
goto CLEANUP;
}
}
if (kws != NULL) {
launch_config = PyDict_GetItemString(kws, launch_config_kw);
if (launch_config != NULL) {
Py_INCREF(launch_config);
if (PyDict_DelItemString(kws, launch_config_kw) < 0) {
Py_DECREF(launch_config);
launch_config = NULL;
goto CLEANUP;
}
if (launch_config == Py_None) {
Py_DECREF(launch_config);
launch_config = NULL;
}
}
}
if (self->fold_args) {
if (find_named_args(self, &args, &kws))
if (find_named_args(self, &args, &kws)) {
Py_XDECREF(launch_config);
return NULL;
}
}
else
Py_INCREF(args);
Expand Down Expand Up @@ -913,6 +1006,11 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
} else if (matches == 0) {
/* No matching definition */
if (self->can_compile) {
LaunchConfigGuard guard(launch_config);
if (guard.failed()) {
retval = NULL;
goto CLEANUP;
}
retval = cuda_compile_only(self, args, kws, locals);
} else if (self->fallbackdef) {
/* Have object fallback */
Expand All @@ -924,6 +1022,11 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
}
} else if (self->can_compile) {
/* Ambiguous, but are allowed to compile */
LaunchConfigGuard guard(launch_config);
if (guard.failed()) {
retval = NULL;
goto CLEANUP;
}
retval = cuda_compile_only(self, args, kws, locals);
} else {
/* Ambiguous */
Expand All @@ -935,6 +1038,7 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
if (tys != prealloc)
delete[] tys;
Py_DECREF(args);
Py_XDECREF(launch_config);

return retval;
}
Expand Down Expand Up @@ -1040,10 +1144,23 @@ static PyObject *compute_fingerprint(PyObject *self, PyObject *args)
return typeof_compute_fingerprint(val);
}

/*
 * Module-level function: return a new reference to the launch config
 * stored in this thread's TSS slot, or None when no config is active.
 */
static PyObject *
get_current_launch_config(PyObject *self, PyObject *args)
{
    PyObject *cfg = launch_config_get_borrowed();
    if (cfg != NULL) {
        Py_INCREF(cfg);  /* promote the borrowed reference for the caller */
        return cfg;
    }
    Py_RETURN_NONE;
}

/* Method table for the _dispatcher extension module. */
static PyMethodDef ext_methods[] = {
/* Helper macro: declares a METH_VARARGS entry named after the C function. */
#define declmethod(func) { #func , ( PyCFunction )func , METH_VARARGS , NULL }
    declmethod(typeof_init),
    declmethod(compute_fingerprint),
    /* Listed by hand because it uses METH_NOARGS, which declmethod()
     * cannot express. */
    { "get_current_launch_config", (PyCFunction)get_current_launch_config,
      METH_NOARGS, NULL },
    { NULL },  /* sentinel */
#undef declmethod
};
Expand All @@ -1055,6 +1172,10 @@ MOD_INIT(_dispatcher) {
if (m == NULL)
return MOD_ERROR_VAL;

if (launch_config_tss_init() != 0) {
return MOD_ERROR_VAL;
}

DispatcherType.tp_new = PyType_GenericNew;
if (PyType_Ready(&DispatcherType) < 0) {
return MOD_ERROR_VAL;
Expand Down
12 changes: 12 additions & 0 deletions numba_cuda/numba/cuda/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from numba.cuda.core.interpreter import Interpreter

from numba.cuda import cgutils, typing, lowering, nvvmutils, utils
from numba.cuda import launchconfig
from numba.cuda.api import get_current_device
from numba.cuda.codegen import ExternalCodeLibrary

Expand Down Expand Up @@ -398,6 +399,14 @@ def run_pass(self, state):
"""
lowered = state["cr"]
signature = typing.signature(state.return_type, *state.args)
launch_cfg = launchconfig.current_launch_config()
if (
launch_cfg is not None
and launch_cfg.is_kernel_launch_config_sensitive()
):
if state.metadata is None:
state.metadata = {}
state.metadata["launch_config_sensitive"] = True

state.cr = cuda_compile_result(
typing_context=state.typingctx,
Expand All @@ -408,6 +417,9 @@ def run_pass(self, state):
call_helper=lowered.call_helper,
signature=signature,
fndesc=lowered.fndesc,
# Preserve metadata populated by rewrite passes (e.g. launch-config
# sensitivity) so downstream consumers can act on it.
metadata=state.metadata,
)
return True

Expand Down
Loading