Skip to content

Commit 4d1ec3e

Browse files
authored
[TOOLS] Move backend template src files to third_party/tools folder (#5411)
Triton AOT compiler use template files for cuda program interface generation now, I think it's more reasonable to put them to `third_party/tools` folder instead of public folder. Also `package_data` in `setup.py` needs change (follows `language/extra`), because `compile.py` itself will be executed from the python interpreter. <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 2f8fe0d commit 4d1ec3e

File tree

5 files changed

+45
-18
lines changed

5 files changed

+45
-18
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ python/triton/backends/
2323
# Language extras
2424
python/triton/language/extra
2525

26+
# Tools extras
27+
python/triton/tools/extra
28+
2629
# Proton
2730
python/triton/profiler
2831

python/setup.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ class Backend:
3434
name: str
3535
package_data: List[str]
3636
language_package_data: List[str]
37+
tools_package_data: List[str]
3738
src_dir: str
3839
backend_dir: str
3940
language_dir: Optional[str]
41+
tools_dir: Optional[str]
4042
install_dir: str
4143
is_external: bool
4244

@@ -68,6 +70,10 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool =
6870
if not os.path.exists(language_dir):
6971
language_dir = None
7072

73+
tools_dir = os.path.abspath(os.path.join(backend_src_dir, "tools"))
74+
if not os.path.exists(tools_dir):
75+
tools_dir = None
76+
7177
for file in ["compiler.py", "driver.py"]:
7278
assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}"
7379

@@ -78,9 +84,13 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool =
7884
if language_dir is not None:
7985
language_package_data = [f"{os.path.relpath(p, language_dir)}/*" for p, _, _, in os.walk(language_dir)]
8086

87+
tools_package_data = []
88+
if tools_dir is not None:
89+
tools_package_data = [f"{os.path.relpath(p, tools_dir)}/*" for p, _, _, in os.walk(tools_dir)]
90+
8191
return Backend(name=backend_name, package_data=package_data, language_package_data=language_package_data,
82-
src_dir=backend_src_dir, backend_dir=backend_path, language_dir=language_dir,
83-
install_dir=install_dir, is_external=is_external)
92+
tools_package_data=tools_package_data, src_dir=backend_src_dir, backend_dir=backend_path,
93+
language_dir=language_dir, tools_dir=tools_dir, install_dir=install_dir, is_external=is_external)
8494

8595
# Copy all in-tree backends under triton/third_party.
8696
@staticmethod
@@ -598,6 +608,15 @@ def add_link_to_backends():
598608
install_dir = os.path.join(extra_dir, x)
599609
update_symlink(install_dir, src_dir)
600610

611+
if backend.tools_dir:
612+
# Link the contents of each backend's `tools` directory into
613+
# `triton.tools.extra`.
614+
extra_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "triton", "tools", "extra"))
615+
for x in os.listdir(backend.tools_dir):
616+
src_dir = os.path.join(backend.tools_dir, x)
617+
install_dir = os.path.join(extra_dir, x)
618+
update_symlink(install_dir, src_dir)
619+
601620

602621
def add_link_to_proton():
603622
proton_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "third_party", "proton", "proton"))
@@ -640,28 +659,31 @@ def run(self):
640659

641660

642661
package_data = {
643-
"triton/tools": ["compile.h", "compile.c"], **{f"triton/backends/{b.name}": b.package_data
644-
for b in backends}, "triton/language/extra": sum(
645-
(b.language_package_data for b in backends), [])
662+
"triton/tools/extra": sum((b.tools_package_data for b in backends), []),
663+
**{f"triton/backends/{b.name}": b.package_data
664+
for b in backends}, "triton/language/extra": sum((b.language_package_data for b in backends), [])
646665
}
647666

648667

649-
def get_language_extra_packages():
668+
def get_extra_packages(extra_name):
650669
packages = []
670+
extra_file_extensions = {"language": (".py"), "tools": (".c", ".h", ".cpp")}
671+
assert extra_name in extra_file_extensions, f"{extra_name} extra is not valid"
672+
651673
for backend in backends:
652-
if backend.language_dir is None:
674+
backend_extra_dir = getattr(backend, f"{extra_name}_dir", None)
675+
if backend_extra_dir is None:
653676
continue
654677

655-
# Walk the `language` directory of each backend to enumerate
656-
# any subpackages, which will be added to `triton.language.extra`.
657-
for dir, dirs, files in os.walk(backend.language_dir, followlinks=True):
658-
if not any(f for f in files if f.endswith(".py")) or dir == backend.language_dir:
659-
# Ignore directories with no python files.
660-
# Also ignore the root directory which corresponds to
661-
# "triton/language/extra".
678+
# Walk the specified directory of each backend to enumerate
679+
# any subpackages, which will be added to extra_package.
680+
for dir, dirs, files in os.walk(backend_extra_dir, followlinks=True):
681+
if not any(f for f in files if f.endswith(extra_file_extensions[extra_name])) or dir == backend_extra_dir:
682+
# Ignore directories with no relevant files
683+
# or the root directory
662684
continue
663-
subpackage = os.path.relpath(dir, backend.language_dir)
664-
package = os.path.join("triton/language/extra", subpackage)
685+
subpackage = os.path.relpath(dir, backend_extra_dir)
686+
package = os.path.join(f"triton/{extra_name}/extra", subpackage)
665687
packages.append(package)
666688

667689
return list(packages)
@@ -677,9 +699,11 @@ def get_packages():
677699
"triton/runtime",
678700
"triton/backends",
679701
"triton/tools",
702+
"triton/tools/extra",
680703
]
681704
packages += [f'triton/backends/{backend.name}' for backend in backends]
682-
packages += get_language_extra_packages()
705+
packages += get_extra_packages("language")
706+
packages += get_extra_packages("tools")
683707
if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
684708
packages += ["triton/profiler"]
685709

python/triton/tools/compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,6 @@ def constexpr(s):
151151
"_placeholder": "",
152152
}
153153
for ext in ['h', 'c']:
154-
template_path = Path(__file__).parent / f"compile.{ext}"
154+
template_path = Path(__file__).parent / "extra" / "cuda" / f"compile.{ext}"
155155
with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp:
156156
fp.write(Path(template_path).read_text().format(**params))
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)