You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[NVIDIA] Enable Programmatic Dependent Launch in Triton (#6394)
Programmatic Dependent Launch (PDL) enables kernels within the same CUDA
stream to overlap while programmatically resolving inter-kernel
dependencies. This allows consecutive kernels to overlap their ramp-down
and ramp-up periods, efficiently hiding prologue latencies. Inter-kernel
dependencies are resolved using Grid Dependency Control (GDC), which
ensures that a kernel waits before reading memory written by the
preceding kernel. This feature is utilized in libraries including
[CUTLASS](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/dependent_kernel_launch.md).
Effectively utilizing PDL in Triton requires using
`tl.extra.cuda.gdc_wait()` to wait for the prior kernel to finish
writing its results. The most straightforward approach is to execute
`tl.extra.cuda.gdc_wait()` before any `tl.load`, based on the
conservative assumption that the prior kernel may be launched with PDL
and can write to any memory location.
When using PDL, `tl.extra.cuda.gdc_launch_dependents()` allows for the
current kernel to trigger the next kernel to start. See the [CUDA
documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization)
for more information.
We utilize this feature in a simple non-persistent kernel with a
conservative approach to inter-kernel dependencies on Blackwell in
tutorial 11. This kernel achieves up to a 15% speedup:

More advanced patterns with PDL we can achieve up to 33% performance
benefits on back-to-back layers in LLMs (see [_LLM Inference Performance
and Optimization on NVIDIA GB200
NVL72_](https://www.nvidia.com/en-us/on-demand/session/gtc25-s72503/) at
GTC 2025 for more details).
---------
Co-authored-by: dePaul Miller <[email protected]>
Co-authored-by: peterbell10 <[email protected]>
This script demonstrates the use of programmatic dependent launch (PDL) ontop of the vector-add example using Triton.
5
+
6
+
For CUDA reference on programmatic dependent launch see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization.
7
+
For PTX reference on programmatic dependent launch see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol.
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
270
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', '+arg_declsiflen(arg_decls) >0else''}) {{
268
271
void *params[] = {{ {', '.join(params)} }};
269
272
if (gridX*gridY*gridZ > 0) {{
270
-
if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{
0 commit comments