Skip to content

Commit e3ab295

Browse files
authored
[BUILD] Add a stable symlink to llvm in the triton cache (#5234)
Currently the llvm path changes every time the pin updates which makes it annoying to use the included tools. e.g. I use the tablegen language server, but currently need to update my editor config every time the llvm pin changes. This adds a stable symlink which for me is `~/.triton/llvm/llvm-macos-x64`. This will always point to the most recent version of llvm used to build triton. As a bonus this also refactors the symlink update code which was copy-pasted a few times.
1 parent 6404fbb commit e3ab295

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

python/setup.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from io import BytesIO
1515
from distutils.command.clean import clean
1616
from pathlib import Path
17-
from typing import List, NamedTuple, Optional
17+
from typing import List, Optional
1818

1919
from setuptools import Extension, setup
2020
from setuptools.command.build_ext import build_ext
@@ -148,13 +148,15 @@ def is_offline_build() -> bool:
148148
# --- third party packages -----
149149

150150

151-
class Package(NamedTuple):
151+
@dataclass
152+
class Package:
152153
package: str
153154
name: str
154155
url: str
155156
include_flag: str
156157
lib_flag: str
157158
syspath_var_name: str
159+
sym_name: Optional[str] = None
158160

159161

160162
# json
@@ -207,8 +209,10 @@ def get_llvm_package_info():
207209
with open(llvm_hash_path, "r") as llvm_hash_file:
208210
rev = llvm_hash_file.read(8)
209211
name = f"llvm-{rev}-{system_suffix}"
212+
# Create a stable symlink that doesn't include revision
213+
sym_name = f"llvm-{system_suffix}"
210214
url = f"https://oaitriton.blob.core.windows.net/public/llvm-builds/{name}.tar.gz"
211-
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
215+
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH", sym_name=sym_name)
212216

213217

214218
def open_url(url):
@@ -233,6 +237,19 @@ def get_triton_cache_path():
233237
return os.path.join(user_home, ".triton")
234238

235239

240+
def update_symlink(link_path, source_path):
241+
source_path = Path(source_path)
242+
link_path = Path(link_path)
243+
244+
if link_path.is_symlink():
245+
link_path.unlink()
246+
elif link_path.exists():
247+
shutil.rmtree(link_path)
248+
249+
print(f"creating symlink: {link_path} -> {source_path}", file=sys.stderr)
250+
link_path.symlink_to(source_path, target_is_directory=True)
251+
252+
236253
def get_thirdparty_packages(packages: list):
237254
triton_cache_path = get_triton_cache_path()
238255
thirdparty_cmake_args = []
@@ -269,6 +286,10 @@ def get_thirdparty_packages(packages: list):
269286
thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
270287
if p.lib_flag:
271288
thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib")
289+
if p.sym_name is not None:
290+
sym_link_path = os.path.join(package_root_dir, p.sym_name)
291+
update_symlink(sym_link_path, package_dir)
292+
272293
return thirdparty_cmake_args
273294

274295

@@ -565,11 +586,7 @@ def get_platform_dependent_src_path(subdir):
565586

566587
def add_link_to_backends():
567588
for backend in backends:
568-
if os.path.islink(backend.install_dir):
569-
os.unlink(backend.install_dir)
570-
if os.path.exists(backend.install_dir):
571-
shutil.rmtree(backend.install_dir)
572-
os.symlink(backend.backend_dir, backend.install_dir)
589+
update_symlink(backend.install_dir, backend.backend_dir)
573590

574591
if backend.language_dir:
575592
# Link the contents of each backend's `language` directory into
@@ -578,21 +595,13 @@ def add_link_to_backends():
578595
for x in os.listdir(backend.language_dir):
579596
src_dir = os.path.join(backend.language_dir, x)
580597
install_dir = os.path.join(extra_dir, x)
581-
if os.path.islink(install_dir):
582-
os.unlink(install_dir)
583-
if os.path.exists(install_dir):
584-
shutil.rmtree(install_dir)
585-
os.symlink(src_dir, install_dir)
598+
update_symlink(install_dir, src_dir)
586599

587600

588601
def add_link_to_proton():
589602
proton_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "third_party", "proton", "proton"))
590603
proton_install_dir = os.path.join(os.path.dirname(__file__), "triton", "profiler")
591-
if os.path.islink(proton_install_dir):
592-
os.unlink(proton_install_dir)
593-
if os.path.exists(proton_install_dir):
594-
shutil.rmtree(proton_install_dir)
595-
os.symlink(proton_dir, proton_install_dir)
604+
update_symlink(proton_install_dir, proton_dir)
596605

597606

598607
def add_links():

0 commit comments

Comments
 (0)