diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 960009595e..aae62030e4 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -23,7 +23,7 @@ def quiet(): sys.stdout, sys.stderr = old_stdout, old_stderr -def _build(name, src, srcdir, library_dirs, include_dirs, libraries): +def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compile_args=[]): suffix = sysconfig.get_config_var('EXT_SUFFIX') so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) # try to avoid setuptools if possible @@ -74,6 +74,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] + cc_cmd += extra_compile_args if os.getenv("VERBOSE"): print(" ".join(cc_cmd)) @@ -81,8 +82,6 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): ret = subprocess.check_call(cc_cmd) if ret == 0: return so - # fallback on setuptools - extra_compile_args = [] # extra arguments extra_link_args = [] # create extension module diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 1db9b2c202..74358bc88a 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -5,6 +5,7 @@ import tempfile from pathlib import Path from functools import cached_property +from typing import Optional from triton.runtime.build import _build from triton.runtime.cache import get_cache_manager @@ -14,7 +15,7 @@ from packaging.specifiers import SpecifierSet -def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]: +def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]: """ Looks for the sycl library in known places. 
@@ -22,12 +23,11 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]: include_dir: list of include directories to pass to compiler. Returns: - enriched include_dir and library_dir. + enriched include_dir and the directory containing libsycl.so (None if it was not resolved). Raises: AssertionError: if library was not found. """ - library_dir = [] include_dir = include_dir.copy() assertion_message = ("sycl headers not found, please install `icpx` compiler, " "or provide `ONEAPI_ROOT` environment " @@ -35,7 +35,7 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]: if shutil.which("icpx"): # only `icpx` compiler knows where sycl runtime binaries and header files are - return include_dir, library_dir + return include_dir, None oneapi_root = os.getenv("ONEAPI_ROOT") if oneapi_root: @@ -43,7 +43,7 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]: os.path.join(oneapi_root, "compiler/latest/include"), os.path.join(oneapi_root, "compiler/latest/include/sycl") ] - return include_dir, library_dir + return include_dir, None try: sycl_rt = importlib.metadata.metadata("intel-sycl-rt") @@ -53,15 +53,16 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]: if Version(sycl_rt.get("version", "0.0.0")) in SpecifierSet("<2025.0.0a1"): raise AssertionError(assertion_message) + sycl_dir = None for f in importlib.metadata.files("intel-sycl-rt"): # sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders # being add: include and include/sycl. 
if f.name == "sycl.hpp": include_dir += [f.locate().parent.parent.resolve().as_posix()] if f.name == "libsycl.so": - library_dir += [f.locate().parent.resolve().as_posix()] + sycl_dir = f.locate().parent.resolve().as_posix() - return include_dir, library_dir + return include_dir, sycl_dir class CompilationHelper: @@ -71,6 +72,7 @@ class CompilationHelper: def __init__(self): self._library_dir = None self._include_dir = None + self._libsycl_dir = None self.libraries = ['ze_loader', 'sycl'] @cached_property @@ -78,10 +80,14 @@ def _compute_compilation_options_lazy(self): ze_root = os.getenv("ZE_PATH", default="/usr/local") include_dir = [os.path.join(ze_root, "include")] - include_dir, library_dir = find_sycl(include_dir) + library_dir = [] + include_dir, self._libsycl_dir = find_sycl(include_dir) + if self._libsycl_dir: + library_dir += [self._libsycl_dir] dirname = os.path.dirname(os.path.realpath(__file__)) include_dir += [os.path.join(dirname, "include")] + # TODO: do we need this? library_dir += [os.path.join(dirname, "lib")] self._library_dir = library_dir @@ -97,6 +103,11 @@ def include_dir(self) -> list[str]: self._compute_compilation_options_lazy return self._include_dir + @cached_property + def libsycl_dir(self) -> Optional[str]: + self._compute_compilation_options_lazy + return self._libsycl_dir + compilation_helper = CompilationHelper() @@ -110,8 +121,11 @@ def compile_module_from_src(src, name): src_path = os.path.join(tmpdir, "main.cpp") with open(src_path, "w") as f: f.write(src) + extra_compiler_args = [] + if compilation_helper.libsycl_dir: + extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir] so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir, - compilation_helper.libraries) + compilation_helper.libraries, extra_compile_args=extra_compiler_args) with open(so, "rb") as f: cache_path = cache.put(f.read(), f"{name}.so", binary=True) import importlib.util