Skip to content

Commit 96f96e9

Browse files
Merge OpenAI Triton commit ca4d957 (#4083)
This PR change the Triton base from f1f9ed9 to ca4d957 (Apr 29). Pass rate: 92.08% Please do not squash and merge this PR.
2 parents 82b4715 + 7e42d2c commit 96f96e9

File tree

22 files changed

+588
-238
lines changed

22 files changed

+588
-238
lines changed

MANIFEST.in

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ graft include
55
graft lib
66
graft python/src
77
graft python/test
8-
graft python/triton/backends/amd
9-
graft python/triton/backends/nvidia
10-
graft python/triton/tools/extra/cuda
8+
graft python/triton
119
graft test
1210
graft third_party
1311
graft unittest

bin/RegisterTritonDialects.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
2929
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
3030

31+
#include "nvidia/hopper/include/Transforms/Passes.h"
3132
#include "nvidia/include/Dialect/NVWS/Transforms/Passes.h"
3233
#include "nvidia/include/NVGPUToLLVM/Passes.h"
3334
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
@@ -109,6 +110,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
109110
// NVWS passes
110111
mlir::registerNVWSTransformsPasses();
111112

113+
// NVGPU transform passes
114+
mlir::registerNVHopperTransformsPasses();
115+
112116
registry.insert<
113117
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
114118
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Python API
2929
- :doc:`triton.language <python-api/triton.language>`
3030
- :doc:`triton.testing <python-api/triton.testing>`
3131
- :doc:`Triton semantics <python-api/triton-semantics>`
32+
- :doc:`triton.language.extra.cuda <python-api/triton.language.extra.cuda>`
3233

3334

3435
.. toctree::
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
triton.language.extra.cuda
2+
==========================
3+
4+
.. currentmodule:: triton.language.extra.cuda
5+
6+
Programmatic Dependent Launch
7+
-----------------------------
8+
9+
.. autosummary::
10+
:toctree: generated
11+
:nosignatures:
12+
13+
gdc_wait
14+
gdc_launch_dependents

python/triton/_utils.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,37 @@
1+
from __future__ import annotations
2+
13
from functools import reduce
4+
from typing import Any, Callable, TYPE_CHECKING, Union
5+
6+
if TYPE_CHECKING:
7+
from .language import core
8+
IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
9+
ObjPath = tuple[int, ...]
210

311

4-
def get_iterable_path(iterable, path):
5-
return reduce(lambda a, idx: a[idx], path, iterable)
12+
def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:
13+
return reduce(lambda a, idx: a[idx], path, iterable) # type: ignore[index]
614

715

8-
def set_iterable_path(iterable, path, val):
16+
def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any):
17+
assert len(path) != 0
918
prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
10-
prev[path[-1]] = val
19+
prev[path[-1]] = val # type: ignore[index]
1120

1221

13-
def find_paths_if(iterable, pred):
22+
def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]:
1423
from .language import core
15-
is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
16-
ret = dict()
24+
is_iterable: Callable[[Any], bool] = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
25+
# We need to use dict so that ordering is maintained, while set doesn't guarantee order
26+
ret: dict[ObjPath, None] = {}
1727

18-
def _impl(current, path):
19-
path = (path[0], ) if len(path) == 1 else tuple(path)
28+
def _impl(path: tuple[int, ...], current: Any):
2029
if is_iterable(current):
2130
for idx, item in enumerate(current):
22-
_impl(item, path + (idx, ))
31+
_impl((*path, idx), item)
2332
elif pred(path, current):
24-
if len(path) == 1:
25-
ret[(path[0], )] = None
26-
else:
27-
ret[tuple(path)] = None
28-
29-
if is_iterable(iterable):
30-
_impl(iterable, [])
31-
elif pred(list(), iterable):
32-
ret = {tuple(): None}
33-
else:
34-
ret = dict()
33+
ret[path] = None
34+
35+
_impl((), iterable)
36+
3537
return list(ret.keys())
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
Programmatic Dependent Launch
3+
=====================
4+
This script demonstrates the use of programmatic dependent launch (PDL) ontop of the vector-add example using Triton.
5+
6+
For CUDA reference on programmatic dependent launch see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization.
7+
For PTX reference on programmatic dependent launch see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol.
8+
9+
.. code-block:: bash
10+
python 11-programmatic-dependent-launch.py
11+
"""
12+
13+
import torch
14+
import triton
15+
import triton.language as tl
16+
17+
18+
def is_cuda():
19+
return triton.runtime.driver.active.get_current_target().backend == "cuda"
20+
21+
22+
def supports_pdl():
23+
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
24+
25+
26+
# In this example
27+
@triton.jit
28+
def add_kernel(x_ptr, #
29+
y_ptr, #
30+
output_ptr, #
31+
n_elements, #
32+
BLOCK_SIZE: tl.constexpr, #
33+
USE_GDC: tl.constexpr, #
34+
):
35+
pid = tl.program_id(axis=0)
36+
block_start = pid * BLOCK_SIZE
37+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
38+
mask = offsets < n_elements
39+
if USE_GDC:
40+
# GDC wait waits for ALL programs in the the prior kernel to complete before continuing.
41+
# This ensures any memory operations happen before the wait in program order,
42+
# e.g. if the prior kernel writes to x or y the new values will be visible.
43+
tl.extra.cuda.gdc_wait()
44+
45+
x = tl.load(x_ptr + offsets, mask=mask)
46+
y = tl.load(y_ptr + offsets, mask=mask)
47+
if USE_GDC:
48+
# GDC launch dependents hints the runtime system to launch dependent kernels.
49+
# These dependent kernels must also be launched with PDL enabled.
50+
# Once GDC launch has been issued by ALL programs or
51+
# programs have finished, the dependent grid can begin if there are enough resources.
52+
# Note: this by itself provides no additional memory-ordering guarentees, unlike `gdc_wait`
53+
tl.extra.cuda.gdc_launch_dependents()
54+
output = x + y
55+
tl.store(output_ptr + offsets, output, mask=mask)
56+
57+
58+
def add(x: torch.Tensor, y: torch.Tensor, launch_pdl: bool = True):
59+
output = torch.empty_like(x)
60+
assert x.device == y.device and output.device == x.device
61+
n_elements = output.numel()
62+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
63+
add_kernel[grid](
64+
x, y, output, n_elements, BLOCK_SIZE=1024,
65+
USE_GDC=launch_pdl, # set constexpr in kernel to use grid dependence control
66+
launch_pdl=launch_pdl, # launch kernel with PDL flag set enabled
67+
)
68+
return output
69+
70+
71+
def validate(n_elements):
72+
x = torch.rand(n_elements, device="cuda", dtype=torch.float32)
73+
y = torch.rand(n_elements, device="cuda", dtype=torch.float32)
74+
75+
torch_result = x + y
76+
add_result = add(x, y)
77+
78+
torch_vs_add = "✅" if torch.allclose(torch_result, add_result, atol=1.0) else "❌"
79+
print(f"Number of Elements={n_elements} verification naive vs: ", end="")
80+
print(f"add: {torch_vs_add}")
81+
82+
83+
@triton.testing.perf_report(
84+
triton.testing.Benchmark(
85+
x_names=["size"],
86+
x_vals=[2**i for i in range(23, 28, 1)],
87+
x_log=False,
88+
line_arg="provider",
89+
line_vals=["pdl-fp32", "fp32"],
90+
line_names=["PDL", "No PDL"],
91+
styles=[("red", "-"), ("blue", "-")],
92+
ylabel='GB/s',
93+
plot_name="pdl-performance",
94+
args={},
95+
))
96+
def benchmark(size, provider):
97+
x = torch.rand(size, device="cuda", dtype=torch.float32)
98+
y = torch.rand(size, device="cuda", dtype=torch.float32)
99+
100+
quantiles = [0.5, 0.2, 0.8]
101+
102+
fn = lambda: add(x, y, "pdl" in provider)
103+
104+
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles, rep=100)
105+
106+
gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
107+
return gbps(ms), gbps(max_ms), gbps(min_ms)
108+
109+
110+
if __name__ == "__main__":
111+
112+
if supports_pdl():
113+
validate(1024)
114+
benchmark.run(print_data=True, show_plots=True, save_path=".")
115+
else:
116+
print("PDL is not supported on this device")

setup.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,19 @@
2020
from setuptools.command.build_ext import build_ext
2121
from setuptools.command.build_py import build_py
2222
from setuptools.command.develop import develop
23+
from setuptools.command.egg_info import egg_info
24+
from setuptools.command.install import install
25+
from setuptools.command.sdist import sdist
26+
2327
from dataclasses import dataclass
2428

2529
import pybind11
2630

31+
try:
32+
from setuptools.command.bdist_wheel import bdist_wheel
33+
except ImportError:
34+
from wheel.bdist_wheel import bdist_wheel
35+
2736
try:
2837
from setuptools.command.editable_wheel import editable_wheel
2938
except ImportError:
@@ -602,6 +611,10 @@ def get_package_dirs():
602611
yield ("", "python")
603612

604613
for backend in backends:
614+
# we use symlinks for external plugins
615+
if backend.is_external:
616+
continue
617+
605618
yield (f"triton.backends.{backend.name}", backend.backend_dir)
606619

607620
if backend.language_dir:
@@ -620,8 +633,33 @@ def get_package_dirs():
620633
yield ("triton.profiler", "third_party/proton/proton")
621634

622635

623-
def add_link_to_backends():
636+
def get_packages():
637+
yield from find_packages(where="python")
638+
639+
for backend in backends:
640+
yield f"triton.backends.{backend.name}"
641+
642+
if backend.language_dir:
643+
# Install the contents of each backend's `language` directory into
644+
# `triton.language.extra`.
645+
for x in os.listdir(backend.language_dir):
646+
yield f"triton.language.extra.{x}"
647+
648+
if backend.tools_dir:
649+
# Install the contents of each backend's `tools` directory into
650+
# `triton.tools.extra`.
651+
for x in os.listdir(backend.tools_dir):
652+
yield f"triton.tools.extra.{x}"
653+
654+
if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
655+
yield "triton.profiler"
656+
657+
658+
def add_link_to_backends(external_only):
624659
for backend in backends:
660+
if external_only and not backend.is_external:
661+
continue
662+
625663
update_symlink(backend.install_dir, backend.backend_dir)
626664

627665
if backend.language_dir:
@@ -650,23 +688,53 @@ def add_link_to_proton():
650688
update_symlink(proton_install_dir, proton_dir)
651689

652690

653-
def add_links():
654-
add_link_to_backends()
655-
if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
691+
def add_links(external_only):
692+
add_link_to_backends(external_only=external_only)
693+
if not external_only and check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
656694
add_link_to_proton()
657695

658696

697+
class plugin_bdist_wheel(bdist_wheel):
698+
699+
def run(self):
700+
add_links(external_only=True)
701+
super().run()
702+
703+
659704
class plugin_develop(develop):
660705

661706
def run(self):
662-
add_links()
707+
add_links(external_only=False)
663708
super().run()
664709

665710

666711
class plugin_editable_wheel(editable_wheel):
667712

668713
def run(self):
669-
add_links()
714+
add_links(external_only=False)
715+
super().run()
716+
717+
718+
class plugin_egg_info(egg_info):
719+
720+
def run(self):
721+
add_links(external_only=True)
722+
super().run()
723+
724+
725+
class plugin_install(install):
726+
727+
def run(self):
728+
add_links(external_only=True)
729+
super().run()
730+
731+
732+
class plugin_sdist(sdist):
733+
734+
def run(self):
735+
for backend in backends:
736+
if backend.is_external:
737+
raise RuntimeError("sdist cannot be used with TRITON_PLUGIN_DIRS")
670738
super().run()
671739

672740

@@ -708,9 +776,6 @@ def get_git_version_suffix():
708776
# keep it separate for easy substitution
709777
TRITON_VERSION = "3.3.0" + get_git_version_suffix() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", "")
710778

711-
package_dirs = dict(get_package_dirs())
712-
extra_packages = [x for x in package_dirs if x != ""]
713-
714779
setup(
715780
name=os.environ.get("TRITON_WHEEL_NAME", "triton"),
716781
version=TRITON_VERSION,
@@ -722,17 +787,21 @@ def get_git_version_suffix():
722787
"setuptools>=78.1.0",
723788
"importlib-metadata; python_version < '3.10'",
724789
],
725-
packages=find_packages(where="python") + extra_packages,
726-
package_dir=package_dirs,
790+
packages=list(get_packages()),
791+
package_dir=dict(get_package_dirs()),
727792
entry_points=get_entry_points(),
728793
include_package_data=True,
729794
ext_modules=[CMakeExtension("triton", "triton/_C/")],
730795
cmdclass={
796+
"bdist_wheel": plugin_bdist_wheel,
731797
"build_ext": CMakeBuild,
732798
"build_py": CMakeBuildPy,
733799
"clean": CMakeClean,
734800
"develop": plugin_develop,
735801
"editable_wheel": plugin_editable_wheel,
802+
"egg_info": plugin_egg_info,
803+
"install": plugin_install,
804+
"sdist": plugin_sdist,
736805
},
737806
zip_safe=False,
738807
# for PyPI

0 commit comments

Comments
 (0)