Skip to content

Commit 53166ef

Browse files
Allow third-party backends to add submodules to triton.language.extra (#4503)
Add an optional language directory to backends. The contents of the directory is added to `triton.language.extra` when the wheel is built. Update the existing `triton.language.extra.cuda` and `triton.language.extra.hip` modules to use the new mechanism. The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because `It is already tested by python/test/unit/language/test_core.py::test_math_extern`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent ab07e54 commit 53166ef

File tree

8 files changed

+53
-11
lines changed

8 files changed

+53
-11
lines changed

python/setup.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@
3232
@dataclass
3333
class Backend:
3434
name: str
35-
package_data: dict
35+
package_data: list[str]
36+
language_package_data: list[str]
3637
src_dir: str
3738
backend_dir: str
39+
language_dir: str
3840
install_dir: str
3941
is_external: bool
4042

@@ -62,12 +64,22 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool =
6264
backend_path = os.path.abspath(os.path.join(backend_src_dir, "backend"))
6365
assert os.path.exists(backend_path), f"{backend_path} does not exist!"
6466

67+
language_dir = os.path.abspath(os.path.join(backend_src_dir, "language"))
68+
if not os.path.exists(language_dir):
69+
language_dir = None
70+
6571
for file in ["compiler.py", "driver.py"]:
6672
assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}"
6773

6874
install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name)
6975
package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)]
70-
return Backend(name=backend_name, package_data=package_data, src_dir=backend_src_dir, backend_dir=backend_path,
76+
77+
language_package_data = []
78+
if language_dir is not None:
79+
language_package_data = [f"{os.path.relpath(p, language_dir)}/*" for p, _, _, in os.walk(language_dir)]
80+
81+
return Backend(name=backend_name, package_data=package_data, language_package_data=language_package_data,
82+
src_dir=backend_src_dir, backend_dir=backend_path, language_dir=language_dir,
7183
install_dir=install_dir, is_external=is_external)
7284

7385
# Copy all in-tree backends under triton/third_party.
@@ -556,6 +568,19 @@ def add_link_to_backends():
556568
shutil.rmtree(backend.install_dir)
557569
os.symlink(backend.backend_dir, backend.install_dir)
558570

571+
if backend.language_dir:
572+
# Link the contents of each backend's `language` directory into
573+
# `triton.language.extra`.
574+
extra_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "triton", "language", "extra"))
575+
for x in os.listdir(backend.language_dir):
576+
src_dir = os.path.join(backend.language_dir, x)
577+
install_dir = os.path.join(extra_dir, x)
578+
if os.path.islink(install_dir):
579+
os.unlink(install_dir)
580+
if os.path.exists(install_dir):
581+
shutil.rmtree(install_dir)
582+
os.symlink(src_dir, install_dir)
583+
559584

560585
def add_link_to_proton():
561586
proton_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "third_party", "proton", "proton"))
@@ -602,28 +627,49 @@ def run(self):
602627

603628

604629
package_data = {
605-
"triton/tools": ["compile.h", "compile.c"],
606-
**{f"triton/backends/{b.name}": b.package_data
607-
for b in backends},
630+
"triton/tools": ["compile.h", "compile.c"], **{f"triton/backends/{b.name}": b.package_data
631+
for b in backends}, "triton/language/extra": sum(
632+
(b.language_package_data for b in backends), [])
608633
}
609634

610635

636+
def get_language_extra_packages():
637+
packages = []
638+
for backend in backends:
639+
if backend.language_dir is None:
640+
continue
641+
642+
# Walk the `language` directory of each backend to enumerate
643+
# any subpackages, which will be added to `triton.language.extra`.
644+
for dir, dirs, files in os.walk(backend.language_dir, followlinks=True):
645+
if not any(f for f in files if f.endswith(".py")) or dir == backend.language_dir:
646+
# Ignore directories with no python files.
647+
# Also ignore the root directory which corresponds to
648+
# "triton/language/extra".
649+
continue
650+
subpackage = os.path.relpath(dir, backend.language_dir)
651+
package = os.path.join("triton/language/extra", subpackage)
652+
packages.append(package)
653+
654+
return list(packages)
655+
656+
611657
def get_packages():
612658
packages = [
613659
"triton",
614660
"triton/_C",
615661
"triton/compiler",
616662
"triton/language",
617663
"triton/language/extra",
618-
"triton/language/extra/cuda",
619-
"triton/language/extra/hip",
620664
"triton/runtime",
621665
"triton/backends",
622666
"triton/tools",
623667
]
624668
packages += [f'triton/backends/{backend.name}' for backend in backends]
669+
packages += get_language_extra_packages()
625670
if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
626671
packages += ["triton/profiler"]
672+
627673
return packages
628674

629675

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +0,0 @@
1-
from . import cuda
2-
from . import hip
3-
4-
__all__ = ['cuda', 'hip']

0 commit comments

Comments
 (0)