55import tempfile
66from pathlib import Path
77from functools import cached_property
8+ from typing import Optional
89
910from triton .runtime .build import _build
1011from triton .runtime .cache import get_cache_manager
1415from packaging .specifiers import SpecifierSet
1516
1617
17- def find_sycl (include_dir : list [str ]) -> tuple [list [str ], list [str ]]:
18+ def find_sycl (include_dir : list [str ]) -> tuple [list [str ], Optional [str ]]:
1819 """
1920 Looks for the sycl library in known places.
2021
2122 Arguments:
2223 include_dir: list of include directories to pass to compiler.
2324
2425 Returns:
25- enriched include_dir and library_dir .
26+ enriched include_dir and libsycl.so location .
2627
2728 Raises:
2829 AssertionError: if library was not found.
2930 """
30- library_dir = []
3131 include_dir = include_dir .copy ()
3232 assertion_message = ("sycl headers not found, please install `icpx` compiler, "
3333 "or provide `ONEAPI_ROOT` environment "
3434 "or install `intel-sycl-rt>=2025.0.0` wheel" )
3535
3636 if shutil .which ("icpx" ):
3737 # only `icpx` compiler knows where sycl runtime binaries and header files are
38- return include_dir , library_dir
38+ return include_dir , None
3939
4040 oneapi_root = os .getenv ("ONEAPI_ROOT" )
4141 if oneapi_root :
4242 include_dir += [
4343 os .path .join (oneapi_root , "compiler/latest/include" ),
4444 os .path .join (oneapi_root , "compiler/latest/include/sycl" )
4545 ]
46- return include_dir , library_dir
46+ return include_dir , None
4747
4848 try :
4949 sycl_rt = importlib .metadata .metadata ("intel-sycl-rt" )
@@ -53,15 +53,16 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
5353 if Version (sycl_rt .get ("version" , "0.0.0" )) in SpecifierSet ("<2025.0.0a1" ):
5454 raise AssertionError (assertion_message )
5555
56+ sycl_dir = None
5657 for f in importlib .metadata .files ("intel-sycl-rt" ):
5758 # sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders
5859 # being add: include and include/sycl.
5960 if f .name == "sycl.hpp" :
6061 include_dir += [f .locate ().parent .parent .resolve ().as_posix ()]
6162 if f .name == "libsycl.so" :
62- library_dir += [ f .locate ().parent .resolve ().as_posix ()]
63+ sycl_dir = f .locate ().parent .resolve ().as_posix ()
6364
64- return include_dir , library_dir
65+ return include_dir , sycl_dir
6566
6667
6768class CompilationHelper :
@@ -71,17 +72,22 @@ class CompilationHelper:
7172 def __init__ (self ):
7273 self ._library_dir = None
7374 self ._include_dir = None
75+ self ._libsycl_dir = None
7476 self .libraries = ['ze_loader' , 'sycl' ]
7577
7678 @cached_property
7779 def _compute_compilation_options_lazy (self ):
7880 ze_root = os .getenv ("ZE_PATH" , default = "/usr/local" )
7981 include_dir = [os .path .join (ze_root , "include" )]
8082
81- include_dir , library_dir = find_sycl (include_dir )
83+ library_dir = []
84+ include_dir , self ._libsycl_dir = find_sycl (include_dir )
85+ if self ._libsycl_dir :
86+ library_dir += [self ._libsycl_dir ]
8287
8388 dirname = os .path .dirname (os .path .realpath (__file__ ))
8489 include_dir += [os .path .join (dirname , "include" )]
90+ # TODO: do we need this?
8591 library_dir += [os .path .join (dirname , "lib" )]
8692
8793 self ._library_dir = library_dir
@@ -97,6 +103,11 @@ def include_dir(self) -> list[str]:
97103 self ._compute_compilation_options_lazy
98104 return self ._include_dir
99105
106+ @cached_property
107+ def libsycl_dir (self ) -> Optional [str ]:
108+ self ._compute_compilation_options_lazy
109+ return self ._libsycl_dir
110+
100111
101112compilation_helper = CompilationHelper ()
102113
@@ -110,8 +121,11 @@ def compile_module_from_src(src, name):
110121 src_path = os .path .join (tmpdir , "main.cpp" )
111122 with open (src_path , "w" ) as f :
112123 f .write (src )
124+ extra_compiler_args = []
125+ if compilation_helper .libsycl_dir :
126+ extra_compiler_args += ['-Wl,-rpath,' + compilation_helper .libsycl_dir ]
113127 so = _build (name , src_path , tmpdir , compilation_helper .library_dir , compilation_helper .include_dir ,
114- compilation_helper .libraries )
128+ compilation_helper .libraries , extra_compile_args = extra_compiler_args )
115129 with open (so , "rb" ) as f :
116130 cache_path = cache .put (f .read (), f"{ name } .so" , binary = True )
117131 import importlib .util
0 commit comments