@@ -67,6 +67,10 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
6767class CompilationHelper :
6868 _library_dir : list [str ]
6969 _include_dir : list [str ]
70+ libraries : list [str ]
71+
72+ # for benchmarks
73+ _build_with_pytorch_dep : bool = False
7074
7175 def __init__ (self ):
7276 self ._library_dir = None
@@ -76,6 +80,12 @@ def __init__(self):
7680 if os .name != "nt" :
7781 self .libraries += ["sycl" ]
7882
83+ def inject_pytorch_dep (self ):
84+ # must be called before any cached properties (if pytorch is needed)
85+ if self ._build_with_pytorch_dep is False :
86+ self ._build_with_pytorch_dep = True
87+ self .libraries += ['torch' ]
88+
7989 @cached_property
8090 def _compute_compilation_options_lazy (self ):
8191 ze_root = os .getenv ("ZE_PATH" , default = "/usr/local" )
@@ -90,9 +100,18 @@ def _compute_compilation_options_lazy(self):
90100
91101 dirname = os .path .dirname (os .path .realpath (__file__ ))
92102 include_dir += [os .path .join (dirname , "include" )]
93- # TODO: do we need this?
94103 library_dir += [os .path .join (dirname , "lib" )]
95104
105+ if self ._build_with_pytorch_dep :
106+ import torch
107+
108+ torch_path = torch .utils .cmake_prefix_path
109+ include_dir += [
110+ os .path .join (torch_path , "../../include" ),
111+ os .path .join (torch_path , "../../include/torch/csrc/api/include" ),
112+ ]
113+ library_dir += [os .path .join (torch_path , "../../lib" )]
114+
96115 self ._library_dir = library_dir
97116 self ._include_dir = include_dir
98117
@@ -112,7 +131,7 @@ def libsycl_dir(self) -> Optional[str]:
112131 return self ._libsycl_dir
113132
114133
115- compilation_helper = CompilationHelper ()
134+ COMPILATION_HELPER = CompilationHelper ()
116135
117136
118137def compile_module_from_src (src , name ):
@@ -126,10 +145,10 @@ def compile_module_from_src(src, name):
126145 with open (src_path , "w" ) as f :
127146 f .write (src )
128147 extra_compiler_args = []
129- if compilation_helper .libsycl_dir :
130- extra_compiler_args += ['-Wl,-rpath,' + compilation_helper .libsycl_dir ]
131- so = _build (name , src_path , tmpdir , compilation_helper .library_dir , compilation_helper .include_dir ,
132- compilation_helper .libraries , extra_compile_args = extra_compiler_args )
148+ if COMPILATION_HELPER .libsycl_dir :
149+ extra_compiler_args += ['-Wl,-rpath,' + COMPILATION_HELPER .libsycl_dir ]
150+ so = _build (name , src_path , tmpdir , COMPILATION_HELPER .library_dir , COMPILATION_HELPER .include_dir ,
151+ COMPILATION_HELPER .libraries , extra_compile_args = extra_compiler_args )
133152 with open (so , "rb" ) as f :
134153 cache_path = cache .put (f .read (), file_name , binary = True )
135154 import importlib .util
0 commit comments