2020from setuptools .command .build_ext import build_ext
2121from setuptools .command .build_py import build_py
2222from setuptools .command .develop import develop
23+ from setuptools .command .egg_info import egg_info
24+ from setuptools .command .install import install
25+ from setuptools .command .sdist import sdist
26+
2327from dataclasses import dataclass
2428
2529import pybind11
2630
31+ try :
32+ from setuptools .command .bdist_wheel import bdist_wheel
33+ except ImportError :
34+ from wheel .bdist_wheel import bdist_wheel
35+
2736try :
2837 from setuptools .command .editable_wheel import editable_wheel
2938except ImportError :
@@ -602,6 +611,10 @@ def get_package_dirs():
602611 yield ("" , "python" )
603612
604613 for backend in backends :
614+ # we use symlinks for external plugins
615+ if backend .is_external :
616+ continue
617+
605618 yield (f"triton.backends.{ backend .name } " , backend .backend_dir )
606619
607620 if backend .language_dir :
@@ -620,8 +633,33 @@ def get_package_dirs():
620633 yield ("triton.profiler" , "third_party/proton/proton" )
621634
622635
623- def add_link_to_backends ():
636+ def get_packages ():
637+ yield from find_packages (where = "python" )
638+
639+ for backend in backends :
640+ yield f"triton.backends.{ backend .name } "
641+
642+ if backend .language_dir :
643+ # Install the contents of each backend's `language` directory into
644+ # `triton.language.extra`.
645+ for x in os .listdir (backend .language_dir ):
646+ yield f"triton.language.extra.{ x } "
647+
648+ if backend .tools_dir :
649+ # Install the contents of each backend's `tools` directory into
650+ # `triton.tools.extra`.
651+ for x in os .listdir (backend .tools_dir ):
652+ yield f"triton.tools.extra.{ x } "
653+
654+ if check_env_flag ("TRITON_BUILD_PROTON" , "ON" ): # Default ON
655+ yield "triton.profiler"
656+
657+
658+ def add_link_to_backends (external_only ):
624659 for backend in backends :
660+ if external_only and not backend .is_external :
661+ continue
662+
625663 update_symlink (backend .install_dir , backend .backend_dir )
626664
627665 if backend .language_dir :
@@ -650,23 +688,53 @@ def add_link_to_proton():
650688 update_symlink (proton_install_dir , proton_dir )
651689
652690
653- def add_links ():
654- add_link_to_backends ()
655- if check_env_flag ("TRITON_BUILD_PROTON" , "ON" ): # Default ON
691+ def add_links (external_only ):
692+ add_link_to_backends (external_only = external_only )
693+ if not external_only and check_env_flag ("TRITON_BUILD_PROTON" , "ON" ): # Default ON
656694 add_link_to_proton ()
657695
658696
697+ class plugin_bdist_wheel (bdist_wheel ):
698+
699+ def run (self ):
700+ add_links (external_only = True )
701+ super ().run ()
702+
703+
659704class plugin_develop (develop ):
660705
661706 def run (self ):
662- add_links ()
707+ add_links (external_only = False )
663708 super ().run ()
664709
665710
666711class plugin_editable_wheel (editable_wheel ):
667712
668713 def run (self ):
669- add_links ()
714+ add_links (external_only = False )
715+ super ().run ()
716+
717+
718+ class plugin_egg_info (egg_info ):
719+
720+ def run (self ):
721+ add_links (external_only = True )
722+ super ().run ()
723+
724+
725+ class plugin_install (install ):
726+
727+ def run (self ):
728+ add_links (external_only = True )
729+ super ().run ()
730+
731+
732+ class plugin_sdist (sdist ):
733+
734+ def run (self ):
735+ for backend in backends :
736+ if backend .is_external :
737+ raise RuntimeError ("sdist cannot be used with TRITON_PLUGIN_DIRS" )
670738 super ().run ()
671739
672740
@@ -708,9 +776,6 @@ def get_git_version_suffix():
708776# keep it separate for easy substitution
709777TRITON_VERSION = "3.3.0" + get_git_version_suffix () + os .environ .get ("TRITON_WHEEL_VERSION_SUFFIX" , "" )
710778
711- package_dirs = dict (get_package_dirs ())
712- extra_packages = [x for x in package_dirs if x != "" ]
713-
714779setup (
715780 name = os .environ .get ("TRITON_WHEEL_NAME" , "triton" ),
716781 version = TRITON_VERSION ,
@@ -722,17 +787,21 @@ def get_git_version_suffix():
722787 "setuptools>=78.1.0" ,
723788 "importlib-metadata; python_version < '3.10'" ,
724789 ],
725- packages = find_packages ( where = "python" ) + extra_packages ,
726- package_dir = package_dirs ,
790+ packages = list ( get_packages ()) ,
791+ package_dir = dict ( get_package_dirs ()) ,
727792 entry_points = get_entry_points (),
728793 include_package_data = True ,
729794 ext_modules = [CMakeExtension ("triton" , "triton/_C/" )],
730795 cmdclass = {
796+ "bdist_wheel" : plugin_bdist_wheel ,
731797 "build_ext" : CMakeBuild ,
732798 "build_py" : CMakeBuildPy ,
733799 "clean" : CMakeClean ,
734800 "develop" : plugin_develop ,
735801 "editable_wheel" : plugin_editable_wheel ,
802+ "egg_info" : plugin_egg_info ,
803+ "install" : plugin_install ,
804+ "sdist" : plugin_sdist ,
736805 },
737806 zip_safe = False ,
738807 # for PyPI
0 commit comments