1414from packaging .specifiers import SpecifierSet
1515
1616
17- def find_sycl (include_dir : list [str ]) -> tuple [list [str ], list [ str ] ]:
17+ def find_sycl (include_dir : list [str ]) -> tuple [list [str ], str ]:
1818 """
1919 Looks for the sycl library in known places.
2020
2121 Arguments:
2222 include_dir: list of include directories to pass to compiler.
2323
2424 Returns:
25- enriched include_dir and library_dir .
25+ enriched include_dir and libsycl.so location .
2626
2727 Raises:
2828 AssertionError: if library was not found.
2929 """
30- library_dir = []
3130 include_dir = include_dir .copy ()
3231 assertion_message = ("sycl headers not found, please install `icpx` compiler, "
3332 "or provide `ONEAPI_ROOT` environment "
3433 "or install `intel-sycl-rt>=2025.0.0` wheel" )
3534
3635 if shutil .which ("icpx" ):
3736 # only `icpx` compiler knows where sycl runtime binaries and header files are
38- return include_dir , library_dir
37+ return include_dir , None
3938
4039 oneapi_root = os .getenv ("ONEAPI_ROOT" )
4140 if oneapi_root :
4241 include_dir += [
4342 os .path .join (oneapi_root , "compiler/latest/include" ),
4443 os .path .join (oneapi_root , "compiler/latest/include/sycl" )
4544 ]
46- return include_dir , library_dir
45+ return include_dir , None
4746
4847 try :
4948 sycl_rt = importlib .metadata .metadata ("intel-sycl-rt" )
@@ -53,15 +52,16 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
5352 if Version (sycl_rt .get ("version" , "0.0.0" )) in SpecifierSet ("<2025.0.0a1" ):
5453 raise AssertionError (assertion_message )
5554
55+ sycl_dir = None
5656 for f in importlib .metadata .files ("intel-sycl-rt" ):
5757 # sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders
5858 # being add: include and include/sycl.
5959 if f .name == "sycl.hpp" :
6060 include_dir += [f .locate ().parent .parent .resolve ().as_posix ()]
6161 if f .name == "libsycl.so" :
62- library_dir += [ f .locate ().parent .resolve ().as_posix ()]
62+ sycl_dir = f .locate ().parent .resolve ().as_posix ()
6363
64- return include_dir , library_dir
64+ return include_dir , sycl_dir
6565
6666
6767class CompilationHelper :
@@ -71,17 +71,22 @@ class CompilationHelper:
7171 def __init__ (self ):
7272 self ._library_dir = None
7373 self ._include_dir = None
74+ self ._libsycl_dir = None
7475 self .libraries = ['ze_loader' , 'sycl' ]
7576
7677 @cached_property
7778 def _compute_compilation_options_lazy (self ):
7879 ze_root = os .getenv ("ZE_PATH" , default = "/usr/local" )
7980 include_dir = [os .path .join (ze_root , "include" )]
8081
81- include_dir , library_dir = find_sycl (include_dir )
82+ library_dir = []
83+ include_dir , self ._libsycl_dir = find_sycl (include_dir )
84+ if self ._libsycl_dir :
85+ library_dir += [self ._libsycl_dir ]
8286
8387 dirname = os .path .dirname (os .path .realpath (__file__ ))
8488 include_dir += [os .path .join (dirname , "include" )]
89+ # TODO: do we need this?
8590 library_dir += [os .path .join (dirname , "lib" )]
8691
8792 self ._library_dir = library_dir
@@ -97,6 +102,11 @@ def include_dir(self) -> list[str]:
97102 self ._compute_compilation_options_lazy
98103 return self ._include_dir
99104
105+ @cached_property
106+ def libsycl_dir (self ) -> list [str ]:
107+ self ._compute_compilation_options_lazy
108+ return self ._libsycl_dir
109+
100110
101111compilation_helper = CompilationHelper ()
102112
@@ -110,8 +120,11 @@ def compile_module_from_src(src, name):
110120 src_path = os .path .join (tmpdir , "main.cpp" )
111121 with open (src_path , "w" ) as f :
112122 f .write (src )
123+ extra_compiler_args = []
124+ if compilation_helper .libsycl_dir :
125+ extra_compiler_args += ['-Wl,-rpath,' + compilation_helper .libsycl_dir ]
113126 so = _build (name , src_path , tmpdir , compilation_helper .library_dir , compilation_helper .include_dir ,
114- compilation_helper .libraries )
127+ compilation_helper .libraries , extra_compile_args = extra_compiler_args )
115128 with open (so , "rb" ) as f :
116129 cache_path = cache .put (f .read (), f"{ name } .so" , binary = True )
117130 import importlib .util
0 commit comments