Skip to content

Commit 784b550

Browse files
authored
Use symlinks for external plugins to fix TRITON_PLUGIN_DIRS (#6627)
Disable the new backend installing logic for external plugins, since `package_dir` does not accept absolute paths. Instead, use a hybrid approach where in-tree backends are installed using the new logic and external backends are symlinked. This implies that source distributions cannot be created when using external plugins. Fixes #6612 ------ <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because it touches build system only. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 33869db commit 784b550

File tree

2 files changed

+81
-14
lines changed

2 files changed

+81
-14
lines changed

MANIFEST.in

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ graft include
55
graft lib
66
graft python/src
77
graft python/test
8-
graft python/triton/backends/amd
9-
graft python/triton/backends/nvidia
10-
graft python/triton/tools/extra/cuda
8+
graft python/triton
119
graft test
1210
graft third_party
1311
graft unittest

setup.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,19 @@
2020
from setuptools.command.build_ext import build_ext
2121
from setuptools.command.build_py import build_py
2222
from setuptools.command.develop import develop
23+
from setuptools.command.egg_info import egg_info
24+
from setuptools.command.install import install
25+
from setuptools.command.sdist import sdist
26+
2327
from dataclasses import dataclass
2428

2529
import pybind11
2630

31+
try:
32+
from setuptools.command.bdist_wheel import bdist_wheel
33+
except ImportError:
34+
from wheel.bdist_wheel import bdist_wheel
35+
2736
try:
2837
from setuptools.command.editable_wheel import editable_wheel
2938
except ImportError:
@@ -587,6 +596,10 @@ def get_package_dirs():
587596
yield ("", "python")
588597

589598
for backend in backends:
599+
# we use symlinks for external plugins
600+
if backend.is_external:
601+
continue
602+
590603
yield (f"triton.backends.{backend.name}", backend.backend_dir)
591604

592605
if backend.language_dir:
@@ -605,8 +618,33 @@ def get_package_dirs():
605618
yield ("triton.profiler", "third_party/proton/proton")
606619

607620

608-
def add_link_to_backends():
621+
def get_packages():
622+
yield from find_packages(where="python")
623+
624+
for backend in backends:
625+
yield f"triton.backends.{backend.name}"
626+
627+
if backend.language_dir:
628+
# Install the contents of each backend's `language` directory into
629+
# `triton.language.extra`.
630+
for x in os.listdir(backend.language_dir):
631+
yield f"triton.language.extra.{x}"
632+
633+
if backend.tools_dir:
634+
# Install the contents of each backend's `tools` directory into
635+
# `triton.tools.extra`.
636+
for x in os.listdir(backend.tools_dir):
637+
yield f"triton.tools.extra.{x}"
638+
639+
if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
640+
yield "triton.profiler"
641+
642+
643+
def add_link_to_backends(external_only):
609644
for backend in backends:
645+
if external_only and not backend.is_external:
646+
continue
647+
610648
update_symlink(backend.install_dir, backend.backend_dir)
611649

612650
if backend.language_dir:
@@ -635,23 +673,53 @@ def add_link_to_proton():
635673
update_symlink(proton_install_dir, proton_dir)
636674

637675

638-
def add_links():
639-
add_link_to_backends()
640-
if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
676+
def add_links(external_only):
677+
add_link_to_backends(external_only=external_only)
678+
if not external_only and check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
641679
add_link_to_proton()
642680

643681

682+
class plugin_bdist_wheel(bdist_wheel):
683+
684+
def run(self):
685+
add_links(external_only=True)
686+
super().run()
687+
688+
644689
class plugin_develop(develop):
645690

646691
def run(self):
647-
add_links()
692+
add_links(external_only=False)
648693
super().run()
649694

650695

651696
class plugin_editable_wheel(editable_wheel):
652697

653698
def run(self):
654-
add_links()
699+
add_links(external_only=False)
700+
super().run()
701+
702+
703+
class plugin_egg_info(egg_info):
704+
705+
def run(self):
706+
add_links(external_only=True)
707+
super().run()
708+
709+
710+
class plugin_install(install):
711+
712+
def run(self):
713+
add_links(external_only=True)
714+
super().run()
715+
716+
717+
class plugin_sdist(sdist):
718+
719+
def run(self):
720+
for backend in backends:
721+
if backend.is_external:
722+
raise RuntimeError("sdist cannot be used with TRITON_PLUGIN_DIRS")
655723
super().run()
656724

657725

@@ -693,9 +761,6 @@ def get_git_version_suffix():
693761
# keep it separate for easy substitution
694762
TRITON_VERSION = "3.3.0" + get_git_version_suffix() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", "")
695763

696-
package_dirs = dict(get_package_dirs())
697-
extra_packages = [x for x in package_dirs if x != ""]
698-
699764
setup(
700765
name=os.environ.get("TRITON_WHEEL_NAME", "triton"),
701766
version=TRITON_VERSION,
@@ -707,17 +772,21 @@ def get_git_version_suffix():
707772
"setuptools>=40.8.0",
708773
"importlib-metadata; python_version < '3.10'",
709774
],
710-
packages=find_packages(where="python") + extra_packages,
711-
package_dir=package_dirs,
775+
packages=list(get_packages()),
776+
package_dir=dict(get_package_dirs()),
712777
entry_points=get_entry_points(),
713778
include_package_data=True,
714779
ext_modules=[CMakeExtension("triton", "triton/_C/")],
715780
cmdclass={
781+
"bdist_wheel": plugin_bdist_wheel,
716782
"build_ext": CMakeBuild,
717783
"build_py": CMakeBuildPy,
718784
"clean": CMakeClean,
719785
"develop": plugin_develop,
720786
"editable_wheel": plugin_editable_wheel,
787+
"egg_info": plugin_egg_info,
788+
"install": plugin_install,
789+
"sdist": plugin_sdist,
721790
},
722791
zip_safe=False,
723792
# for PyPI

0 commit comments

Comments
 (0)