
Commit ca4d957

[NVIDIA] Enable Programmatic Dependent Launch in Triton (#6394)
Programmatic Dependent Launch (PDL) enables kernels within the same CUDA stream to overlap while programmatically resolving inter-kernel dependencies. This allows consecutive kernels to overlap their ramp-down and ramp-up periods, efficiently hiding prologue latencies. Inter-kernel dependencies are resolved using Grid Dependency Control (GDC), which ensures that a kernel waits before reading memory written by the preceding kernel. This feature is used in libraries including [CUTLASS](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/dependent_kernel_launch.md).

Effectively using PDL in Triton requires calling `tl.extra.cuda.gdc_wait()` to wait for the prior kernel to finish writing its results. The most straightforward approach is to execute `tl.extra.cuda.gdc_wait()` before any `tl.load`, based on the conservative assumption that the prior kernel may have been launched with PDL and can write to any memory location. When using PDL, `tl.extra.cuda.gdc_launch_dependents()` allows the current kernel to trigger the launch of the next kernel. See the [CUDA documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization) for more information.

We use this feature in tutorial 11: a simple non-persistent kernel with a conservative approach to inter-kernel dependencies, run on Blackwell. This kernel achieves up to a 15% speedup:

![pdl_performance](https://github.com/user-attachments/assets/5d0a9a3b-38f6-4ae1-9a94-7ade22099a4f)

With more advanced PDL patterns, we can achieve up to 33% performance benefits on back-to-back layers in LLMs (see [_LLM Inference Performance and Optimization on NVIDIA GB200 NVL72_](https://www.nvidia.com/en-us/on-demand/session/gtc25-s72503/) at GTC 2025 for more details).

---------

Co-authored-by: dePaul Miller <[email protected]>
Co-authored-by: peterbell10 <[email protected]>
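As an illustration of the back-to-back pattern described above, here is a minimal sketch that is not part of this commit; `producer_kernel` and `consumer_kernel` are hypothetical names, and both kernels are launched on the same stream with `launch_pdl=True`:

```python
# Illustrative sketch only; producer_kernel/consumer_kernel are hypothetical.
import torch
import triton
import triton.language as tl


@triton.jit
def producer_kernel(buf_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(buf_ptr + offsets, offsets.to(tl.float32), mask=mask)
    # Hint that dependent kernels may begin launching; memory ordering is
    # still enforced by gdc_wait() in the consumer.
    tl.extra.cuda.gdc_launch_dependents()


@triton.jit
def consumer_kernel(buf_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Conservative placement: wait before the first load so the producer's
    # writes to buf are guaranteed to be visible.
    tl.extra.cuda.gdc_wait()
    x = tl.load(buf_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)


n = 1 << 20
buf = torch.empty(n, device="cuda", dtype=torch.float32)
out = torch.empty_like(buf)
grid = (triton.cdiv(n, 1024), )
producer_kernel[grid](buf, n, BLOCK_SIZE=1024, launch_pdl=True)
consumer_kernel[grid](buf, out, n, BLOCK_SIZE=1024, launch_pdl=True)
```

Placing `gdc_wait()` before the first `tl.load` is the conservative choice described above; more aggressive placements can overlap more work but require reasoning about exactly which buffers the prior kernel writes.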
1 parent b4dcd8e commit ca4d957

File tree

7 files changed: +244 −63 lines


docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ Python API
 - :doc:`triton.language <python-api/triton.language>`
 - :doc:`triton.testing <python-api/triton.testing>`
 - :doc:`Triton semantics <python-api/triton-semantics>`
+- :doc:`triton.language.extra.cuda <python-api/triton.language.extra.cuda>`


 .. toctree::
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
triton.language.extra.cuda
==========================

.. currentmodule:: triton.language.extra.cuda

Programmatic Dependent Launch
-----------------------------

.. autosummary::
    :toctree: generated
    :nosignatures:

    gdc_wait
    gdc_launch_dependents
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
"""
Programmatic Dependent Launch
=============================
This script demonstrates the use of programmatic dependent launch (PDL) on top of the vector-add example using Triton.

For the CUDA reference on programmatic dependent launch see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization.
For the PTX reference on programmatic dependent launch see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol.

.. code-block:: bash

    python 11-programmatic-dependent-launch.py
"""

import torch
import triton
import triton.language as tl


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_pdl():
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


# In this example the vector-add kernel optionally uses grid dependency control (GDC).
@triton.jit
def add_kernel(x_ptr,  #
               y_ptr,  #
               output_ptr,  #
               n_elements,  #
               BLOCK_SIZE: tl.constexpr,  #
               USE_GDC: tl.constexpr,  #
               ):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    if USE_GDC:
        # GDC wait waits for ALL programs in the prior kernel to complete before continuing.
        # This ensures any memory operations happen before the wait in program order,
        # e.g. if the prior kernel writes to x or y the new values will be visible.
        tl.extra.cuda.gdc_wait()

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    if USE_GDC:
        # GDC launch dependents hints the runtime system to launch dependent kernels.
        # These dependent kernels must also be launched with PDL enabled.
        # Once GDC launch has been issued by ALL programs or
        # programs have finished, the dependent grid can begin if there are enough resources.
        # Note: this by itself provides no additional memory-ordering guarantees, unlike `gdc_wait`.
        tl.extra.cuda.gdc_launch_dependents()
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor, launch_pdl: bool = True):
    output = torch.empty_like(x)
    assert x.device == y.device and output.device == x.device
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    add_kernel[grid](
        x, y, output, n_elements, BLOCK_SIZE=1024,
        USE_GDC=launch_pdl,  # set constexpr in kernel to use grid dependency control
        launch_pdl=launch_pdl,  # launch the kernel with the PDL flag enabled
    )
    return output


def validate(n_elements):
    x = torch.rand(n_elements, device="cuda", dtype=torch.float32)
    y = torch.rand(n_elements, device="cuda", dtype=torch.float32)

    torch_result = x + y
    add_result = add(x, y)

    torch_vs_add = "✅" if torch.allclose(torch_result, add_result, atol=1.0) else "❌"
    print(f"Number of Elements={n_elements} verification naive vs: ", end="")
    print(f"add: {torch_vs_add}")


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],
        x_vals=[2**i for i in range(23, 28, 1)],
        x_log=False,
        line_arg="provider",
        line_vals=["pdl-fp32", "fp32"],
        line_names=["PDL", "No PDL"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel='GB/s',
        plot_name="pdl-performance",
        args={},
    ))
def benchmark(size, provider):
    x = torch.rand(size, device="cuda", dtype=torch.float32)
    y = torch.rand(size, device="cuda", dtype=torch.float32)

    quantiles = [0.5, 0.2, 0.8]

    fn = lambda: add(x, y, "pdl" in provider)

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles, rep=100)

    gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":

    if supports_pdl():
        validate(1024)
        benchmark.run(print_data=True, show_plots=True, save_path=".")
    else:
        print("PDL is not supported on this device")

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ class CUDAOptions:
     ptx_version: int = None
     enable_fp_fusion: bool = True
     launch_cooperative_grid: bool = False
+    launch_pdl: bool = False
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
     deprecated_fp8_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"

third_party/nvidia/backend/driver.py

Lines changed: 67 additions & 63 deletions
@@ -122,6 +122,9 @@ def ty_to_cpp(ty):
     }[ty]


+_BASE_ARGS_FORMAT = "iiiKKppOOOOO"
+
+
 def make_launcher(constants, signature):

     def _expand_signature(sig, output):
@@ -184,7 +187,7 @@ def format_of(ty):
     signature = {i: s for i, s in enumerate(expand_signature)}

     args_format = ''.join([format_of(ty) for ty in signature.values()])
-    format = "iiiKKpOOOOO" + args_format
+    format = _BASE_ARGS_FORMAT + args_format

     flat_signature = []
     for sig in signature.values():
@@ -264,67 +267,65 @@ def format_of(ty):
   return cuLaunchKernelExHandle;
 }}

-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0) {{
-    if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{
-      CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
-    }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{
-      CUlaunchAttribute launchAttr[1];
+    // 4 attributes is the maximum we can currently pass
+    CUlaunchAttribute launchAttr[4];
+    static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
+    if (cuLaunchKernelExHandle == NULL) {{
+      cuLaunchKernelExHandle = getLaunchKernelExHandle();
+    }}
+    CUlaunchConfig config;
+    config.gridDimX = gridX;
+    config.gridDimY = gridY;
+    config.gridDimZ = gridZ;
+
+    if (num_ctas != 1) {{
+      config.gridDimX *= clusterDimX;
+      config.gridDimY *= clusterDimY;
+      config.gridDimZ *= clusterDimZ;
+    }}
+
+    config.blockDimX = 32 * num_warps;
+    config.blockDimY = 1;
+    config.blockDimZ = 1;
+    config.sharedMemBytes = shared_memory;
+    config.hStream = stream;
+    config.attrs = launchAttr;
+    int num_attrs = 0;
+
+    if (launch_pdl != 0) {{
+      CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
+      launchAttr[num_attrs] = pdlAttr;
+      ++num_attrs;
+    }}
+
+    if (launch_cooperative_grid != 0) {{
       CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-      launchAttr[0] = coopAttr;
-
-      CUlaunchConfig config;
-      config.gridDimX = gridX;
-      config.gridDimY = gridY;
-      config.gridDimZ = gridZ;
-      config.blockDimX = 32 * num_warps;
-      config.blockDimY = 1;
-      config.blockDimZ = 1;
-      config.sharedMemBytes = shared_memory;
-      config.hStream = stream;
-      config.attrs = launchAttr;
-      config.numAttrs = 1;
-
-      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-      if (cuLaunchKernelExHandle == NULL) {{
-        cuLaunchKernelExHandle = getLaunchKernelExHandle();
-      }}
-      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
-
-    }} else {{
-      CUlaunchAttribute launchAttr[3];
-      launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-      launchAttr[0].value.clusterDim.x = clusterDimX;
-      launchAttr[0].value.clusterDim.y = clusterDimY;
-      launchAttr[0].value.clusterDim.z = clusterDimZ;
-      launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
-      launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
-
-      unsigned numAttrs = 2;
-      if (0 != launch_cooperative_grid) {{
-        CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
-        launchAttr[2] = coopAttr;
-        numAttrs = 3;
-      }}
-
-      CUlaunchConfig config;
-      config.gridDimX = gridX * clusterDimX;
-      config.gridDimY = gridY * clusterDimY;
-      config.gridDimZ = gridZ * clusterDimZ;
-      config.blockDimX = 32 * num_warps;
-      config.blockDimY = 1;
-      config.blockDimZ = 1;
-      config.sharedMemBytes = shared_memory;
-      config.hStream = stream;
-      config.attrs = launchAttr;
-      config.numAttrs = numAttrs;
-      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
-      if (cuLaunchKernelExHandle == NULL) {{
-        cuLaunchKernelExHandle = getLaunchKernelExHandle();
-      }}
-      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
+      launchAttr[num_attrs] = coopAttr;
+      ++num_attrs;
+    }}
+
+    if (num_ctas != 1) {{
+      CUlaunchAttribute clusterAttr = {{}};
+      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+      clusterAttr.value.clusterDim.x = clusterDimX;
+      clusterAttr.value.clusterDim.y = clusterDimY;
+      clusterAttr.value.clusterDim.z = clusterDimZ;
+      launchAttr[num_attrs] = clusterAttr;
+      ++num_attrs;
+
+      CUlaunchAttribute clusterSchedulingAttr = {{}};
+      clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+      launchAttr[num_attrs] = clusterSchedulingAttr;
+      ++num_attrs;
     }}
+
+    config.numAttrs = num_attrs;
+
+    CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
   }}
 }}
@@ -444,14 +445,15 @@ def format_of(ty):
   uint64_t _stream;
   uint64_t _function;
   int launch_cooperative_grid;
+  int launch_pdl;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  PyObject *global_scratch_obj = NULL;
  {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
-                      &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
+                      &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj,
                       &kernel_metadata, &launch_metadata,
                       &launch_enter_hook, &launch_exit_hook{args_list})) {{
    return NULL;
@@ -485,7 +487,7 @@ def format_of(ty):
  {newline.join(ptr_decls)}
  {newline.join(tma_decls)}
  Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {{
    return NULL;
@@ -584,8 +586,8 @@ def wrap_handle_tensordesc(launcher, tensordesc_meta):
         return launcher

     def inner(*args):
-        meta_args = args[:11]
-        raw_kernel_args = args[11:]
+        meta_args = args[:len(_BASE_ARGS_FORMAT)]
+        raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
         tensordesc_idx = 0
         final_args = []
         for i, arg in enumerate(raw_kernel_args):
@@ -619,6 +621,7 @@ def __init__(self, src, metadata):
         self.global_scratch_size = metadata.global_scratch_size
         self.global_scratch_align = metadata.global_scratch_align
         self.launch_cooperative_grid = metadata.launch_cooperative_grid
+        self.launch_pdl = metadata.launch_pdl

     def __call__(self, gridX, gridY, gridZ, stream, function, *args):
         if self.global_scratch_size > 0:
@@ -627,7 +630,8 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args):
             global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
         else:
             global_scratch = None
-        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
+        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
+                    global_scratch, *args)


 class CudaDriver(GPUDriver):

third_party/nvidia/language/cuda/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,7 @@
 from . import libdevice

 from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80)
+from .gdc import (gdc_launch_dependents, gdc_wait)

 __all__ = [
     "libdevice",
@@ -10,4 +11,6 @@
     "smid",
     "convert_custom_float8_sm70",
     "convert_custom_float8_sm80",
+    "gdc_launch_dependents",
+    "gdc_wait",
 ]

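With these exports, the helpers are reachable both as direct imports from `triton.language.extra.cuda` and under the `tl.extra.cuda` namespace used in the tutorial. A minimal in-kernel sketch, not from this commit and using a hypothetical `copy_kernel`:

```python
# Illustrative only; copy_kernel is a hypothetical example kernel.
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.extra.cuda.gdc_wait()  # wait for the prior PDL-launched kernel's writes
    x = tl.load(src_ptr + offsets, mask=mask)
    tl.extra.cuda.gdc_launch_dependents()  # let the next PDL kernel start launching
    tl.store(dst_ptr + offsets, x, mask=mask)
```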