|
32 | 32 | @dataclass |
33 | 33 | class Backend: |
34 | 34 | name: str |
35 | | - package_data: dict |
| 35 | + package_data: list[str] |
| 36 | + language_package_data: list[str] |
36 | 37 | src_dir: str |
37 | 38 | backend_dir: str |
| 39 | + language_dir: str |
38 | 40 | install_dir: str |
39 | 41 | is_external: bool |
40 | 42 |
|
@@ -62,12 +64,22 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool = |
62 | 64 | backend_path = os.path.abspath(os.path.join(backend_src_dir, "backend")) |
63 | 65 | assert os.path.exists(backend_path), f"{backend_path} does not exist!" |
64 | 66 |
|
| 67 | + language_dir = os.path.abspath(os.path.join(backend_src_dir, "language")) |
| 68 | + if not os.path.exists(language_dir): |
| 69 | + language_dir = None |
| 70 | + |
65 | 71 | for file in ["compiler.py", "driver.py"]: |
66 | 72 | assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}" |
67 | 73 |
|
68 | 74 | install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name) |
69 | 75 | package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)] |
70 | | - return Backend(name=backend_name, package_data=package_data, src_dir=backend_src_dir, backend_dir=backend_path, |
| 76 | + |
| 77 | + language_package_data = [] |
| 78 | + if language_dir is not None: |
| 79 | + language_package_data = [f"{os.path.relpath(p, language_dir)}/*" for p, _, _, in os.walk(language_dir)] |
| 80 | + |
| 81 | + return Backend(name=backend_name, package_data=package_data, language_package_data=language_package_data, |
| 82 | + src_dir=backend_src_dir, backend_dir=backend_path, language_dir=language_dir, |
71 | 83 | install_dir=install_dir, is_external=is_external) |
72 | 84 |
|
73 | 85 | # Copy all in-tree backends under triton/third_party. |
@@ -556,6 +568,19 @@ def add_link_to_backends(): |
556 | 568 | shutil.rmtree(backend.install_dir) |
557 | 569 | os.symlink(backend.backend_dir, backend.install_dir) |
558 | 570 |
|
| 571 | + if backend.language_dir: |
| 572 | + # Link the contents of each backend's `language` directory into |
| 573 | + # `triton.language.extra`. |
| 574 | + extra_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "triton", "language", "extra")) |
| 575 | + for x in os.listdir(backend.language_dir): |
| 576 | + src_dir = os.path.join(backend.language_dir, x) |
| 577 | + install_dir = os.path.join(extra_dir, x) |
| 578 | + if os.path.islink(install_dir): |
| 579 | + os.unlink(install_dir) |
| 580 | + if os.path.exists(install_dir): |
| 581 | + shutil.rmtree(install_dir) |
| 582 | + os.symlink(src_dir, install_dir) |
| 583 | + |
559 | 584 |
|
560 | 585 | def add_link_to_proton(): |
561 | 586 | proton_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "third_party", "proton", "proton")) |
@@ -602,28 +627,49 @@ def run(self): |
602 | 627 |
|
603 | 628 |
|
604 | 629 | package_data = { |
605 | | - "triton/tools": ["compile.h", "compile.c"], |
606 | | - **{f"triton/backends/{b.name}": b.package_data |
607 | | - for b in backends}, |
| 630 | + "triton/tools": ["compile.h", "compile.c"], **{f"triton/backends/{b.name}": b.package_data |
| 631 | + for b in backends}, "triton/language/extra": sum( |
| 632 | + (b.language_package_data for b in backends), []) |
608 | 633 | } |
609 | 634 |
|
610 | 635 |
|
| 636 | +def get_language_extra_packages(): |
| 637 | + packages = [] |
| 638 | + for backend in backends: |
| 639 | + if backend.language_dir is None: |
| 640 | + continue |
| 641 | + |
| 642 | + # Walk the `language` directory of each backend to enumerate |
| 643 | + # any subpackages, which will be added to `triton.language.extra`. |
| 644 | + for dir, dirs, files in os.walk(backend.language_dir, followlinks=True): |
| 645 | + if not any(f for f in files if f.endswith(".py")) or dir == backend.language_dir: |
| 646 | + # Ignore directories with no python files. |
| 647 | + # Also ignore the root directory which corresponds to |
| 648 | + # "triton/language/extra". |
| 649 | + continue |
| 650 | + subpackage = os.path.relpath(dir, backend.language_dir) |
| 651 | + package = os.path.join("triton/language/extra", subpackage) |
| 652 | + packages.append(package) |
| 653 | + |
| 654 | + return list(packages) |
| 655 | + |
| 656 | + |
611 | 657 | def get_packages(): |
612 | 658 | packages = [ |
613 | 659 | "triton", |
614 | 660 | "triton/_C", |
615 | 661 | "triton/compiler", |
616 | 662 | "triton/language", |
617 | 663 | "triton/language/extra", |
618 | | - "triton/language/extra/cuda", |
619 | | - "triton/language/extra/hip", |
620 | 664 | "triton/runtime", |
621 | 665 | "triton/backends", |
622 | 666 | "triton/tools", |
623 | 667 | ] |
624 | 668 | packages += [f'triton/backends/{backend.name}' for backend in backends] |
| 669 | + packages += get_language_extra_packages() |
625 | 670 | if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON |
626 | 671 | packages += ["triton/profiler"] |
| 672 | + |
627 | 673 | return packages |
628 | 674 |
|
629 | 675 |
|
|
0 commit comments