Skip to content

Commit 72249e2

Browse files
committed
Add sycl's rpath to rt build
1 parent 0a3b143 commit 72249e2

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

python/triton/runtime/build.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def quiet():
2323
sys.stdout, sys.stderr = old_stdout, old_stderr
2424

2525

26-
def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
26+
def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compile_args=[]):
2727
suffix = sysconfig.get_config_var('EXT_SUFFIX')
2828
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
2929
# try to avoid setuptools if possible
@@ -74,15 +74,14 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
7474
cc_cmd += [f'-l{lib}' for lib in libraries]
7575
cc_cmd += [f"-L{dir}" for dir in library_dirs]
7676
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
77+
cc_cmd += extra_compile_args
7778

7879
if os.getenv("VERBOSE"):
7980
print(" ".join(cc_cmd))
8081

8182
ret = subprocess.check_call(cc_cmd)
8283
if ret == 0:
8384
return so
84-
# fallback on setuptools
85-
extra_compile_args = []
8685
# extra arguments
8786
extra_link_args = []
8887
# create extension module

third_party/intel/backend/driver.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import tempfile
66
from pathlib import Path
77
from functools import cached_property
8+
from typing import Optional
89

910
from triton.runtime.build import _build
1011
from triton.runtime.cache import get_cache_manager
@@ -14,36 +15,35 @@
1415
from packaging.specifiers import SpecifierSet
1516

1617

17-
def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
18+
def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
1819
"""
1920
Looks for the sycl library in known places.
2021
2122
Arguments:
2223
include_dir: list of include directories to pass to compiler.
2324
2425
Returns:
25-
enriched include_dir and library_dir.
26+
enriched include_dir and libsycl.so location.
2627
2728
Raises:
2829
AssertionError: if library was not found.
2930
"""
30-
library_dir = []
3131
include_dir = include_dir.copy()
3232
assertion_message = ("sycl headers not found, please install `icpx` compiler, "
3333
"or provide `ONEAPI_ROOT` environment "
3434
"or install `intel-sycl-rt>=2025.0.0` wheel")
3535

3636
if shutil.which("icpx"):
3737
# only `icpx` compiler knows where sycl runtime binaries and header files are
38-
return include_dir, library_dir
38+
return include_dir, None
3939

4040
oneapi_root = os.getenv("ONEAPI_ROOT")
4141
if oneapi_root:
4242
include_dir += [
4343
os.path.join(oneapi_root, "compiler/latest/include"),
4444
os.path.join(oneapi_root, "compiler/latest/include/sycl")
4545
]
46-
return include_dir, library_dir
46+
return include_dir, None
4747

4848
try:
4949
sycl_rt = importlib.metadata.metadata("intel-sycl-rt")
@@ -53,15 +53,16 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
5353
if Version(sycl_rt.get("version", "0.0.0")) in SpecifierSet("<2025.0.0a1"):
5454
raise AssertionError(assertion_message)
5555

56+
sycl_dir = None
5657
for f in importlib.metadata.files("intel-sycl-rt"):
5758
# sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders
5859
# being add: include and include/sycl.
5960
if f.name == "sycl.hpp":
6061
include_dir += [f.locate().parent.parent.resolve().as_posix()]
6162
if f.name == "libsycl.so":
62-
library_dir += [f.locate().parent.resolve().as_posix()]
63+
sycl_dir = f.locate().parent.resolve().as_posix()
6364

64-
return include_dir, library_dir
65+
return include_dir, sycl_dir
6566

6667

6768
class CompilationHelper:
@@ -71,17 +72,22 @@ class CompilationHelper:
7172
def __init__(self):
7273
self._library_dir = None
7374
self._include_dir = None
75+
self._libsycl_dir = None
7476
self.libraries = ['ze_loader', 'sycl']
7577

7678
@cached_property
7779
def _compute_compilation_options_lazy(self):
7880
ze_root = os.getenv("ZE_PATH", default="/usr/local")
7981
include_dir = [os.path.join(ze_root, "include")]
8082

81-
include_dir, library_dir = find_sycl(include_dir)
83+
library_dir = []
84+
include_dir, self._libsycl_dir = find_sycl(include_dir)
85+
if self._libsycl_dir:
86+
library_dir += [self._libsycl_dir]
8287

8388
dirname = os.path.dirname(os.path.realpath(__file__))
8489
include_dir += [os.path.join(dirname, "include")]
90+
# TODO: do we need this?
8591
library_dir += [os.path.join(dirname, "lib")]
8692

8793
self._library_dir = library_dir
@@ -97,6 +103,11 @@ def include_dir(self) -> list[str]:
97103
self._compute_compilation_options_lazy
98104
return self._include_dir
99105

106+
@cached_property
107+
def libsycl_dir(self) -> Optional[str]:
108+
self._compute_compilation_options_lazy
109+
return self._libsycl_dir
110+
100111

101112
compilation_helper = CompilationHelper()
102113

@@ -110,8 +121,11 @@ def compile_module_from_src(src, name):
110121
src_path = os.path.join(tmpdir, "main.cpp")
111122
with open(src_path, "w") as f:
112123
f.write(src)
124+
extra_compiler_args = []
125+
if compilation_helper.libsycl_dir:
126+
extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir]
113127
so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir,
114-
compilation_helper.libraries)
128+
compilation_helper.libraries, extra_compile_args=extra_compiler_args)
115129
with open(so, "rb") as f:
116130
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
117131
import importlib.util

0 commit comments

Comments
 (0)