Skip to content

Commit a9a4288

Browse files
committed
Add sycl's rpath to rt build
1 parent 0df7d80 commit a9a4288

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

python/triton/runtime/build.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def quiet():
2323
sys.stdout, sys.stderr = old_stdout, old_stderr
2424

2525

26-
def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
26+
def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compile_args=[]):
2727
suffix = sysconfig.get_config_var('EXT_SUFFIX')
2828
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
2929
# try to avoid setuptools if possible
@@ -74,15 +74,14 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
7474
cc_cmd += [f'-l{lib}' for lib in libraries]
7575
cc_cmd += [f"-L{dir}" for dir in library_dirs]
7676
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
77+
cc_cmd += extra_compile_args
7778

7879
if os.getenv("VERBOSE"):
7980
print(" ".join(cc_cmd))
8081

8182
ret = subprocess.check_call(cc_cmd)
8283
if ret == 0:
8384
return so
84-
# fallback on setuptools
85-
extra_compile_args = []
8685
# extra arguments
8786
extra_link_args = []
8887
# create extension module

third_party/intel/backend/driver.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,35 @@
1414
from packaging.specifiers import SpecifierSet
1515

1616

17-
def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
17+
def find_sycl(include_dir: list[str]) -> tuple[list[str], str]:
1818
"""
1919
Looks for the sycl library in known places.
2020
2121
Arguments:
2222
include_dir: list of include directories to pass to compiler.
2323
2424
Returns:
25-
enriched include_dir and library_dir.
25+
enriched include_dir and libsycl.so location.
2626
2727
Raises:
2828
AssertionError: if library was not found.
2929
"""
30-
library_dir = []
3130
include_dir = include_dir.copy()
3231
assertion_message = ("sycl headers not found, please install `icpx` compiler, "
3332
"or provide `ONEAPI_ROOT` environment "
3433
"or install `intel-sycl-rt>=2025.0.0` wheel")
3534

3635
if shutil.which("icpx"):
3736
# only `icpx` compiler knows where sycl runtime binaries and header files are
38-
return include_dir, library_dir
37+
return include_dir, None
3938

4039
oneapi_root = os.getenv("ONEAPI_ROOT")
4140
if oneapi_root:
4241
include_dir += [
4342
os.path.join(oneapi_root, "compiler/latest/include"),
4443
os.path.join(oneapi_root, "compiler/latest/include/sycl")
4544
]
46-
return include_dir, library_dir
45+
return include_dir, None
4746

4847
try:
4948
sycl_rt = importlib.metadata.metadata("intel-sycl-rt")
@@ -53,15 +52,16 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
5352
if Version(sycl_rt.get("version", "0.0.0")) in SpecifierSet("<2025.0.0a1"):
5453
raise AssertionError(assertion_message)
5554

55+
sycl_dir = None
5656
for f in importlib.metadata.files("intel-sycl-rt"):
5757
# sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders
5858
# being add: include and include/sycl.
5959
if f.name == "sycl.hpp":
6060
include_dir += [f.locate().parent.parent.resolve().as_posix()]
6161
if f.name == "libsycl.so":
62-
library_dir += [f.locate().parent.resolve().as_posix()]
62+
sycl_dir = f.locate().parent.resolve().as_posix()
6363

64-
return include_dir, library_dir
64+
return include_dir, sycl_dir
6565

6666

6767
class CompilationHelper:
@@ -71,17 +71,22 @@ class CompilationHelper:
7171
def __init__(self):
7272
self._library_dir = None
7373
self._include_dir = None
74+
self._libsycl_dir = None
7475
self.libraries = ['ze_loader', 'sycl']
7576

7677
@cached_property
7778
def _compute_compilation_options_lazy(self):
7879
ze_root = os.getenv("ZE_PATH", default="/usr/local")
7980
include_dir = [os.path.join(ze_root, "include")]
8081

81-
include_dir, library_dir = find_sycl(include_dir)
82+
library_dir = []
83+
include_dir, self._libsycl_dir = find_sycl(include_dir)
84+
if self._libsycl_dir:
85+
library_dir += [self._libsycl_dir]
8286

8387
dirname = os.path.dirname(os.path.realpath(__file__))
8488
include_dir += [os.path.join(dirname, "include")]
89+
# TODO: do we need this?
8590
library_dir += [os.path.join(dirname, "lib")]
8691

8792
self._library_dir = library_dir
@@ -97,6 +102,11 @@ def include_dir(self) -> list[str]:
97102
self._compute_compilation_options_lazy
98103
return self._include_dir
99104

105+
@cached_property
106+
def libsycl_dir(self) -> list[str]:
107+
self._compute_compilation_options_lazy
108+
return self._libsycl_dir
109+
100110

101111
compilation_helper = CompilationHelper()
102112

@@ -110,8 +120,11 @@ def compile_module_from_src(src, name):
110120
src_path = os.path.join(tmpdir, "main.cpp")
111121
with open(src_path, "w") as f:
112122
f.write(src)
123+
extra_compiler_args = []
124+
if compilation_helper.libsycl_dir:
125+
extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir]
113126
so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir,
114-
compilation_helper.libraries)
127+
compilation_helper.libraries, extra_compile_args=extra_compiler_args)
115128
with open(so, "rb") as f:
116129
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
117130
import importlib.util

0 commit comments

Comments
 (0)