1414from packaging .specifiers import SpecifierSet
1515
1616
17- def find_sycl (include_dir : list [str ]) -> tuple [list [str ], list [ str ] ]:
17+ def find_sycl (include_dir : list [str ]) -> tuple [list [str ], str ]:
1818 """
1919 Looks for the sycl library in known places.
2020
2121 Arguments:
2222 include_dir: list of include directories to pass to compiler.
2323
2424 Returns:
25- enriched include_dir and library_dir .
25+ enriched include_dir and libsycl.so location .
2626
2727 Raises:
2828 AssertionError: if library was not found.
2929 """
30- library_dir = []
3130 include_dir = include_dir .copy ()
3231 assertion_message = ("sycl headers not found, please install `icpx` compiler, "
3332 "or provide `ONEAPI_ROOT` environment "
3433 "or install `intel-sycl-rt>=2025.0.0` wheel" )
3534
3635 if shutil .which ("icpx" ):
3736 # only `icpx` compiler knows where sycl runtime binaries and header files are
38- return include_dir , library_dir
37+ return include_dir , None
3938
4039 oneapi_root = os .getenv ("ONEAPI_ROOT" )
4140 if oneapi_root :
4241 include_dir += [
4342 os .path .join (oneapi_root , "compiler/latest/include" ),
4443 os .path .join (oneapi_root , "compiler/latest/include/sycl" )
4544 ]
46- return include_dir , library_dir
45+ return include_dir , None
4746
4847 try :
4948 sycl_rt = importlib .metadata .metadata ("intel-sycl-rt" )
@@ -53,15 +52,17 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
5352 if Version (sycl_rt .get ("version" , "0.0.0" )) in SpecifierSet ("<2025.0.0a1" ):
5453 raise AssertionError (assertion_message )
5554
55+ sycl_dir = None
5656 for f in importlib .metadata .files ("intel-sycl-rt" ):
5757 # sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders
5858 # being add: include and include/sycl.
5959 if f .name == "sycl.hpp" :
6060 include_dir += [f .locate ().parent .parent .resolve ().as_posix ()]
6161 if f .name == "libsycl.so" :
62- library_dir += [f .locate ().parent .resolve ().as_posix ()]
62+ sycl_dir = f .locate ().parent .resolve ().as_posix ()
63+ print (sycl_dir )
6364
64- return include_dir , library_dir
65+ return include_dir , sycl_dir
6566
6667
6768class CompilationHelper :
@@ -71,17 +72,22 @@ class CompilationHelper:
7172 def __init__ (self ):
7273 self ._library_dir = None
7374 self ._include_dir = None
75+ self ._libsycl_dir = None
7476 self .libraries = ['ze_loader' , 'sycl' ]
7577
7678 @cached_property
7779 def _compute_compilation_options_lazy (self ):
7880 ze_root = os .getenv ("ZE_PATH" , default = "/usr/local" )
7981 include_dir = [os .path .join (ze_root , "include" )]
8082
81- include_dir , library_dir = find_sycl (include_dir )
83+ library_dir = []
84+ include_dir , self ._libsycl_dir = find_sycl (include_dir )
85+ if self ._libsycl_dir :
86+ library_dir += [self ._libsycl_dir ]
8287
8388 dirname = os .path .dirname (os .path .realpath (__file__ ))
8489 include_dir += [os .path .join (dirname , "include" )]
90+ # TODO: do we need this?
8591 library_dir += [os .path .join (dirname , "lib" )]
8692
8793 self ._library_dir = library_dir
@@ -97,6 +103,11 @@ def include_dir(self) -> list[str]:
97103 self ._compute_compilation_options_lazy
98104 return self ._include_dir
99105
106+ @cached_property
107+ def libsycl_dir (self ) -> list [str ]:
108+ self ._compute_compilation_options_lazy
109+ return self ._libsycl_dir
110+
100111
101112compilation_helper = CompilationHelper ()
102113
@@ -110,8 +121,12 @@ def compile_module_from_src(src, name):
110121 src_path = os .path .join (tmpdir , "main.cpp" )
111122 with open (src_path , "w" ) as f :
112123 f .write (src )
124+ extra_compiler_args = []
125+ if compilation_helper .libsycl_dir :
126+ extra_compiler_args += ['-Wl,-rpath,' + compilation_helper .libsycl_dir ]
127+ print ("EXTRA compiler args: " , extra_compiler_args )
113128 so = _build (name , src_path , tmpdir , compilation_helper .library_dir , compilation_helper .include_dir ,
114- compilation_helper .libraries )
129+ compilation_helper .libraries , extra_compile_args = extra_compiler_args )
115130 with open (so , "rb" ) as f :
116131 cache_path = cache .put (f .read (), f"{ name } .so" , binary = True )
117132 import importlib .util
0 commit comments