
Commit 05922fc

Implement launch config infrastructure. (#804)
## Summary

This PR introduces two related pieces of launch-config infrastructure needed by `cuda.coop` single-phase work:

1. A low-overhead launch-config API that can be consumed from arbitrary numba-cuda compilation stages (including rewrites), plus pre-launch callback registration on configured launches.
2. Launch-config-sensitive (LCS) plumbing so kernels that depend on launch configuration are specialized and cached correctly across launch configs.

## Background and motivation

`cuda.coop` single-phase rewriting needs compile-time access to launch configuration details (grid/block/shared memory/launch args) and a way to register pre-launch hooks at rewrite time (for launch-time kernel argument handling without requiring `@cuda.jit(extensions=...)`).

An earlier implementation (PR #288) provided this via Python `contextvars`, but review feedback showed the launch overhead was too high. This branch reimplements the mechanism through C-extension TLS plumbing in `_dispatcher.cpp`, with negligible overhead in the launch micro-benchmark.

From `bench-launch-overhead.out` (us/launch, baseline vs contextvar vs v2):

- 0 args: `5.56` vs `7.29` (+31.1%) vs `5.56` (+0.0%)
- 1 arg: `7.53` vs `9.18` (+21.8%) vs `7.55` (+0.2%)
- 2 args: `8.90` vs `10.64` (+19.5%) vs `8.97` (+0.8%)
- 3 args: `10.31` vs `12.50` (+21.3%) vs `10.37` (+0.5%)
- 4 args: `11.82` vs `13.56` (+14.7%) vs `11.92` (+0.8%)

## What this PR adds

### 1) Launch-config API with low launch overhead

- The C extension (`numba_cuda/numba/cuda/cext/_dispatcher.cpp`) now carries the active launch config in thread-local storage, and only during compilation paths.
- Python API in `numba_cuda/numba/cuda/launchconfig.py`:
  - `current_launch_config()`
  - `ensure_current_launch_config()`
  - `capture_compile_config()`
- Configured launches expose:
  - launch metadata (`griddim`, `blockdim`, `sharedmem`, `args`, `dispatcher`)
  - `pre_launch_callbacks` for just-in-time launch-time hook registration.
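The configured-launch contract above can be sketched in pure Python. The class below is a hypothetical stand-in, not the numba-cuda object; it only models the callback protocol: each entry in `pre_launch_callbacks` is invoked with `(kernel, launch_config)` immediately before the kernel runs.

```python
# Toy model of a configured launch (illustration only, not the shipped API).
class ToyConfiguredLaunch:
    def __init__(self, griddim, blockdim, sharedmem=0):
        self.griddim = griddim
        self.blockdim = blockdim
        self.sharedmem = sharedmem
        self.pre_launch_callbacks = []

    def __call__(self, kernel, *args):
        # Run registered hooks first; the real API forbids CUDA work here.
        for callback in self.pre_launch_callbacks:
            callback(kernel, self)
        return kernel(*args)

seen = []
launch = ToyConfiguredLaunch(griddim=(1,), blockdim=(32,))
launch.pre_launch_callbacks.append(lambda kernel, cfg: seen.append(cfg.blockdim))
result = launch(lambda x: x + 1, 41)
print(seen, result)  # [(32,)] 42
```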
### 2) Launch-config-sensitive compilation/caching

- Explicit LCS marker API on `_LaunchConfiguration` (`numba_cuda/numba/cuda/dispatcher.py`):
  - `mark_kernel_as_launch_config_sensitive()`
  - `get_kernel_launch_config_sensitive()`
  - `is_kernel_launch_config_sensitive()`
- The CUDA backend (`numba_cuda/numba/cuda/compiler.py`) promotes that mark into compile metadata (`state.metadata["launch_config_sensitive"] = True`).
- Dispatcher/cache behavior for LCS kernels:
  - per-launch-config dispatcher specialization routing
  - per-launch-config disk-cache keys
  - a `.lcs` marker file indicating launch-config-sensitive cache entries.

## Why the LCS piece is required

Without LCS, cache keys are signature-based only, so a kernel compiled once for launch config A can be reused for launch config B without rerunning the rewrite. That breaks launch-config-dependent rewrite behavior. Concrete observed behavior:

- Runtime cache (single process):
  - Launch `[1, 32]`: rewrite runs, callback registered.
  - Launch `[1, 64]` without LCS: the existing kernel is reused, the rewrite does not run, and the callback for the 64-config path is never registered.
  - With LCS marking: the second launch recompiles under a distinct launch-config specialization, so rewrite/callback registration runs for 64.
- Disk cache (cross process):
  - Process 1 compiles and caches launch `[1, 32]`.
  - Process 2 launches `[1, 64]` without LCS: the 32-config artifact can be reused from disk (no rewrite for the 64 path).
  - With LCS marking: process 2 misses on the 64-specific cache key and compiles a 64-specific variant.
- LCS intentionally preserves exact cache hits for matching launch configs. It does not force recompilation when the launch-config key already matches.

So the LCS plumbing is what makes launch-config-dependent rewrite decisions correct under both in-memory and disk cache reuse.

Scope note for `cuda.coop` today:

- `cuda.coop` frequently injects LTO-IR/linking files during compilation.
- numba-cuda currently does not disk-cache kernels with linking files, so for those paths the immediate LCS correctness benefit is runtime/in-memory cache behavior across launch configs.
- Disk-cache LCS behavior applies to launch-config-sensitive kernels that are otherwise disk-cacheable (and remains relevant for future linked-code cache support).

## Safety behavior

- If an LCS kernel is loaded from disk but the `.lcs` marker is missing, we treat that cache state as unsafe, force a recompile, and re-mark.
- If marking fails (e.g. a filesystem error), disk caching is disabled for safety (falling back to `NullCache`) to avoid unsafe reuse.

## Out of scope

- Cache invalidation keyed on `numba_cuda.__version__` (handled by PR #800). Note that PR #800 should be merged, and presumably a release cut, before this PR is merged, so that downstream projects like `cuda.coop` can pin accordingly.
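The cache-keying contrast described under "Why the LCS piece is required" can be sketched as follows. The helper below is hypothetical (real numba-cuda keys carry more fields plus the separate `.lcs` marker file); it only shows why a signature-only key reuses one compiled artifact across launch configs while an LCS key does not.

```python
# Hypothetical cache-key helper illustrating LCS keying (not the real code).
def cache_key(signature, launch_cfg=None):
    if launch_cfg is not None:  # launch-config-sensitive kernel
        return (signature, tuple(launch_cfg))
    return (signature,)

cache = {}

# Without LCS: the artifact compiled for [1, 32] is found for any later
# launch config, so the rewrite never reruns for [1, 64].
cache[cache_key("f(i4[:])")] = "compiled-for-32"
hit_without_lcs = cache_key("f(i4[:])") in cache

# With LCS: [1, 64] misses the [1, 32]-specific key and triggers a
# 64-specific compile; an exact [1, 32] relaunch still hits.
cache[cache_key("f(i4[:])", [1, 32])] = "compiled-for-32"
miss_with_lcs = cache_key("f(i4[:])", [1, 64]) not in cache
exact_hit = cache_key("f(i4[:])", [1, 32]) in cache
print(hit_without_lcs, miss_with_lcs, exact_hit)  # True True True
```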
1 parent de4642c commit 05922fc

File tree

11 files changed: +1135 -19 lines

docs/source/reference/kernel.rst

Lines changed: 100 additions & 0 deletions

@@ -63,9 +63,109 @@ creating a specialized instance:
 .. autoclass:: numba.cuda.dispatcher.CUDADispatcher
    :members: inspect_asm, inspect_llvm, inspect_sass, inspect_types,
       get_regs_per_thread, specialize, specialized, extensions, forall,
+      mark_launch_config_sensitive,
       get_shared_mem_per_block, get_max_threads_per_block,
       get_const_mem_size, get_local_mem_per_thread
 
+Launch configuration access (advanced)
+--------------------------------------
+
+The configured-launch object returned by ``dispatcher[griddim, blockdim, ...]``
+exposes launch metadata and callback hooks that can be consumed by advanced
+tooling (for example, rewrite passes and extension integrations).
+
+Compile-time launch-config access
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The active launch configuration is available only while compilation is in
+progress for a kernel launch. Most consumers should access it from code that
+is already running inside compilation, such as rewrite passes, compiler
+extensions, or other compile-time helpers.
+
+.. note:: This is compile-time state. If a launch reuses an existing compiled
+   kernel for the same cache key, no compilation occurs and no compile-time
+   launch config is set. For launch-config-sensitive kernels, a different
+   launch configuration can trigger a separate compilation/specialization; see
+   :ref:`cuda-launch-config-sensitive-compilation`.
+
+Primary path for compile-time consumers:
+
+.. code-block:: python
+
+   from numba.cuda import launchconfig
+
+   cfg = launchconfig.ensure_current_launch_config()
+   print(cfg.griddim, cfg.blockdim, cfg.sharedmem, cfg.args)
+
+Use ``ensure_current_launch_config()`` when an active launch-triggered
+compilation is required. Use ``current_launch_config()`` instead when the
+absence of compile-time launch state is expected and should return ``None``.
+
+For external introspection or testing:
+
+``capture_compile_config()`` is a convenience helper for observing the
+compile-time launch configuration from ordinary Python code outside compilation
+itself. Most user code does not need it, but it can be useful in tests,
+debugging, or tooling that wants to inspect the launch configuration that
+triggered compilation.
+
+.. code-block:: python
+
+   from numba import cuda
+   from numba.cuda import launchconfig
+
+   @cuda.jit
+   def f(x):
+       x[0] = 1
+
+   arr = cuda.device_array(1, dtype="i4")
+   with launchconfig.capture_compile_config(f) as capture:
+       f[1, 1](arr)  # first launch triggers compilation
+
+   cfg = capture["config"]
+   print(cfg.griddim, cfg.blockdim, cfg.sharedmem)
+
+Pre-launch callbacks
+~~~~~~~~~~~~~~~~~~~~
+
+Configured launches expose ``pre_launch_callbacks``. Each callback is called
+immediately before launch with ``(kernel, launch_config)``.
+
+.. warning:: Pre-launch callbacks must not invoke CUDA APIs or launch CUDA
+   work. This use is not supported, is not tested, and has undefined
+   behavior; it may deadlock or fail in other ways.
+
+.. code-block:: python
+
+   cfg = f[1, 1]
+
+   def log_launch(kernel, cfg):
+       print(cfg.griddim, cfg.blockdim)
+
+   cfg.pre_launch_callbacks.append(log_launch)
+   cfg(arr)
+
+.. _cuda-launch-config-sensitive-compilation:
+
+Launch-config-sensitive compilation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If code generation depends on launch configuration (for example, a rewrite
+pass that inspects ``cfg.blockdim`` and emits different IR), mark the active
+dispatcher as launch-config sensitive:
+
+.. code-block:: python
+
+   cfg = launchconfig.ensure_current_launch_config()
+   cfg.dispatcher.mark_launch_config_sensitive()
+
+This instructs the dispatcher/cache machinery to avoid unsafe reuse across
+different launch configurations for that kernel path.
+
+.. note:: Launch-config-sensitive cache keying for ``cache=True`` applies to
+   kernels that are otherwise disk-cacheable. Kernels that require external
+   linking files are not currently disk-cacheable.
+
 Kernel Arguments
 ----------------

numba_cuda/numba/cuda/cext/_dispatcher.cpp

Lines changed: 191 additions & 1 deletion

@@ -13,6 +13,116 @@
 #include "traceback.h"
 #include "typeconv.hpp"
 
+static Py_tss_t launch_config_tss_key = Py_tss_NEEDS_INIT;
+static Py_tss_t launch_args_tss_key = Py_tss_NEEDS_INIT;
+static const char *launch_config_kw = "__numba_cuda_launch_config__";
+
+static int
+launch_config_tss_init(void)
+{
+    if (PyThread_tss_create(&launch_config_tss_key) != 0) {
+        PyErr_SetString(PyExc_RuntimeError,
+                        "Failed to initialize launch config TLS");
+        return -1;
+    }
+    return 0;
+}
+
+static int
+launch_args_tss_init(void)
+{
+    if (PyThread_tss_create(&launch_args_tss_key) != 0) {
+        PyErr_SetString(PyExc_RuntimeError,
+                        "Failed to initialize launch args TLS");
+        return -1;
+    }
+    return 0;
+}
+
+static PyObject *
+launch_config_get_borrowed(void)
+{
+    return (PyObject *) PyThread_tss_get(&launch_config_tss_key);
+}
+
+static PyObject *
+launch_args_get_borrowed(void)
+{
+    return (PyObject *) PyThread_tss_get(&launch_args_tss_key);
+}
+
+static int
+launch_config_set(PyObject *obj)
+{
+    PyObject *old = (PyObject *) PyThread_tss_get(&launch_config_tss_key);
+    if (obj != NULL) {
+        Py_INCREF(obj);
+    }
+    if (PyThread_tss_set(&launch_config_tss_key, (void *) obj) != 0) {
+        Py_XDECREF(obj);
+        PyErr_SetString(PyExc_RuntimeError,
+                        "Failed to set launch config TLS");
+        return -1;
+    }
+    Py_XDECREF(old);
+    return 0;
+}
+
+static int
+launch_args_set(PyObject *obj)
+{
+    PyObject *old = (PyObject *) PyThread_tss_get(&launch_args_tss_key);
+    if (obj != NULL) {
+        Py_INCREF(obj);
+    }
+    if (PyThread_tss_set(&launch_args_tss_key, (void *) obj) != 0) {
+        Py_XDECREF(obj);
+        PyErr_SetString(PyExc_RuntimeError,
+                        "Failed to set launch args TLS");
+        return -1;
+    }
+    Py_XDECREF(old);
+    return 0;
+}
+
+class LaunchConfigGuard {
+public:
+    explicit LaunchConfigGuard(PyObject *value)
+        : prev(NULL), active(false), requested(value != NULL)
+    {
+        if (!requested) {
+            return;
+        }
+        prev = launch_config_get_borrowed();
+        Py_XINCREF(prev);
+        if (launch_config_set(value) != 0) {
+            Py_XDECREF(prev);
+            prev = NULL;
+            return;
+        }
+        active = true;
+    }
+
+    bool failed(void) const
+    {
+        return requested && !active;
+    }
+
+    ~LaunchConfigGuard(void)
+    {
+        if (!active) {
+            return;
+        }
+        launch_config_set(prev);
+        Py_XDECREF(prev);
+    }
+
+private:
+    PyObject *prev;
+    bool active;
+    bool requested;
+};
+
 /*
  * Notes on the C_TRACE macro:
  *
@@ -840,6 +950,7 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
     PyObject *cfunc;
     PyThreadState *ts = PyThreadState_Get();
     PyObject *locals = NULL;
+    PyObject *launch_config = NULL;
 
     /* If compilation is enabled, ensure that an exact match is found and if
      * not compile one */
@@ -855,9 +966,26 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
             goto CLEANUP;
         }
     }
+    if (kws != NULL) {
+        launch_config = PyDict_GetItemString(kws, launch_config_kw);
+        if (launch_config != NULL) {
+            Py_INCREF(launch_config);
+            if (PyDict_DelItemString(kws, launch_config_kw) < 0) {
+                Py_DECREF(launch_config);
+                launch_config = NULL;
+                goto CLEANUP;
+            }
+            if (launch_config == Py_None) {
+                Py_DECREF(launch_config);
+                launch_config = NULL;
+            }
+        }
+    }
     if (self->fold_args) {
-        if (find_named_args(self, &args, &kws))
+        if (find_named_args(self, &args, &kws)) {
+            Py_XDECREF(launch_config);
             return NULL;
+        }
     }
     else
         Py_INCREF(args);
@@ -913,6 +1041,11 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
     } else if (matches == 0) {
         /* No matching definition */
         if (self->can_compile) {
+            LaunchConfigGuard guard(launch_config);
+            if (guard.failed()) {
+                retval = NULL;
+                goto CLEANUP;
+            }
             retval = cuda_compile_only(self, args, kws, locals);
         } else if (self->fallbackdef) {
             /* Have object fallback */
@@ -924,6 +1057,11 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
         }
     } else if (self->can_compile) {
         /* Ambiguous, but are allowed to compile */
+        LaunchConfigGuard guard(launch_config);
+        if (guard.failed()) {
+            retval = NULL;
+            goto CLEANUP;
+        }
         retval = cuda_compile_only(self, args, kws, locals);
     } else {
         /* Ambiguous */
@@ -935,6 +1073,7 @@ Dispatcher_cuda_call(Dispatcher *self, PyObject *args, PyObject *kws)
     if (tys != prealloc)
         delete[] tys;
     Py_DECREF(args);
+    Py_XDECREF(launch_config);
 
     return retval;
 }
@@ -1040,10 +1179,54 @@ static PyObject *compute_fingerprint(PyObject *self, PyObject *args)
     return typeof_compute_fingerprint(val);
 }
 
+static PyObject *
+get_current_launch_config(PyObject *self, PyObject *args)
+{
+    PyObject *config = launch_config_get_borrowed();
+    if (config == NULL) {
+        Py_RETURN_NONE;
+    }
+    Py_INCREF(config);
+    return config;
+}
+
+static PyObject *
+get_current_launch_args(PyObject *self, PyObject *args)
+{
+    PyObject *launch_args = launch_args_get_borrowed();
+    if (launch_args == NULL) {
+        Py_RETURN_NONE;
+    }
+    Py_INCREF(launch_args);
+    return launch_args;
+}
+
+static PyObject *
+swap_current_launch_args(PyObject *self, PyObject *arg)
+{
+    PyObject *new_args = arg == Py_None ? NULL : arg;
+    PyObject *old_args = launch_args_get_borrowed();
+    Py_XINCREF(old_args);
+    if (launch_args_set(new_args) != 0) {
+        Py_XDECREF(old_args);
+        return NULL;
+    }
+    if (old_args == NULL) {
+        Py_RETURN_NONE;
+    }
+    return old_args;
+}
+
 static PyMethodDef ext_methods[] = {
 #define declmethod(func) { #func , ( PyCFunction )func , METH_VARARGS , NULL }
     declmethod(typeof_init),
     declmethod(compute_fingerprint),
+    { "get_current_launch_config", (PyCFunction)get_current_launch_config,
+      METH_NOARGS, NULL },
+    { "get_current_launch_args", (PyCFunction)get_current_launch_args,
+      METH_NOARGS, NULL },
+    { "swap_current_launch_args", (PyCFunction)swap_current_launch_args,
+      METH_O, NULL },
     { NULL },
 #undef declmethod
 };
@@ -1055,6 +1238,13 @@ MOD_INIT(_dispatcher) {
     if (m == NULL)
        return MOD_ERROR_VAL;
 
+    if (launch_config_tss_init() != 0) {
+        return MOD_ERROR_VAL;
+    }
+    if (launch_args_tss_init() != 0) {
+        return MOD_ERROR_VAL;
+    }
+
     DispatcherType.tp_new = PyType_GenericNew;
     if (PyType_Ready(&DispatcherType) < 0) {
         return MOD_ERROR_VAL;
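The `LaunchConfigGuard` in the diff above follows the RAII save/set/restore pattern. A pure-Python analogue (illustration only, not the shipped code) shows the same contract: save the previous thread-local value, install the new one, and restore the saved value on exit so nested configured launches unwind correctly.

```python
import threading
from contextlib import contextmanager

_tls = threading.local()

def current_launch_config():
    # Returns None when no launch-triggered compilation is in progress.
    return getattr(_tls, "config", None)

@contextmanager
def launch_config_guard(value):
    prev = current_launch_config()
    _tls.config = value
    try:
        yield
    finally:
        _tls.config = prev  # restore, even if compilation raised

with launch_config_guard("outer"):
    with launch_config_guard("inner"):
        assert current_launch_config() == "inner"
    assert current_launch_config() == "outer"
assert current_launch_config() is None
```

The C++ version additionally reports `failed()` when installing the value into TLS errors out, so the caller can abort the launch instead of compiling with stale state.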

numba_cuda/numba/cuda/compiler.py

Lines changed: 12 additions & 0 deletions

@@ -15,6 +15,7 @@
 from numba.cuda.core.interpreter import Interpreter
 
 from numba.cuda import cgutils, typing, lowering, nvvmutils, utils
+from numba.cuda import launchconfig
 from numba.cuda.api import get_current_device
 from numba.cuda.codegen import ExternalCodeLibrary
 
@@ -398,6 +399,14 @@ def run_pass(self, state):
         """
         lowered = state["cr"]
         signature = typing.signature(state.return_type, *state.args)
+        launch_cfg = launchconfig.current_launch_config()
+        if (
+            launch_cfg is not None
+            and launch_cfg._is_kernel_launch_config_sensitive()
+        ):
+            if state.metadata is None:
+                state.metadata = {}
+            state.metadata["launch_config_sensitive"] = True
 
         state.cr = cuda_compile_result(
             typing_context=state.typingctx,
@@ -408,6 +417,9 @@ def run_pass(self, state):
             call_helper=lowered.call_helper,
             signature=signature,
             fndesc=lowered.fndesc,
+            # Preserve metadata populated by rewrite passes (e.g. launch-config
+            # sensitivity) so downstream consumers can act on it.
+            metadata=state.metadata,
         )
         return True
