@@ -68,6 +68,10 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
6868class CompilationHelper :
6969 _library_dir : list [str ]
7070 _include_dir : list [str ]
71+ libraries : list [str ]
72+
73+ # for benchmarks
74+ _build_with_pytorch_dep : bool = False
7175
7276 def __init__ (self ):
7377 self ._library_dir = None
@@ -77,6 +81,12 @@ def __init__(self):
7781 if os .name != "nt" :
7882 self .libraries += ["sycl" ]
7983
84+ def inject_pytorch_dep (self ):
85+ # must be called before any cached properties (if pytorch is needed)
86+ if self ._build_with_pytorch_dep is False :
87+ self ._build_with_pytorch_dep = True
88+ self .libraries += ['torch' ]
89+
8090 @cached_property
8191 def _compute_compilation_options_lazy (self ):
8292 ze_root = os .getenv ("ZE_PATH" , default = "/usr/local" )
@@ -91,9 +101,18 @@ def _compute_compilation_options_lazy(self):
91101
92102 dirname = os .path .dirname (os .path .realpath (__file__ ))
93103 include_dir += [os .path .join (dirname , "include" )]
94- # TODO: do we need this?
95104 library_dir += [os .path .join (dirname , "lib" )]
96105
106+ if self ._build_with_pytorch_dep :
107+ import torch
108+
109+ torch_path = torch .utils .cmake_prefix_path
110+ include_dir += [
111+ os .path .join (torch_path , "../../include" ),
112+ os .path .join (torch_path , "../../include/torch/csrc/api/include" ),
113+ ]
114+ library_dir += [os .path .join (torch_path , "../../lib" )]
115+
97116 self ._library_dir = library_dir
98117 self ._include_dir = include_dir
99118
@@ -113,7 +132,7 @@ def libsycl_dir(self) -> Optional[str]:
113132 return self ._libsycl_dir
114133
115134
116- compilation_helper = CompilationHelper ()
135+ COMPILATION_HELPER = CompilationHelper ()
117136
118137
119138def compile_module_from_src (src , name ):
@@ -127,10 +146,10 @@ def compile_module_from_src(src, name):
127146 with open (src_path , "w" ) as f :
128147 f .write (src )
129148 extra_compiler_args = []
130- if compilation_helper .libsycl_dir :
131- extra_compiler_args += ['-Wl,-rpath,' + compilation_helper .libsycl_dir ]
132- so = _build (name , src_path , tmpdir , compilation_helper .library_dir , compilation_helper .include_dir ,
133- compilation_helper .libraries , extra_compile_args = extra_compiler_args )
149+ if COMPILATION_HELPER .libsycl_dir :
150+ extra_compiler_args += ['-Wl,-rpath,' + COMPILATION_HELPER .libsycl_dir ]
151+ so = _build (name , src_path , tmpdir , COMPILATION_HELPER .library_dir , COMPILATION_HELPER .include_dir ,
152+ COMPILATION_HELPER .libraries , extra_compile_args = extra_compiler_args )
134153 with open (so , "rb" ) as f :
135154 cache_path = cache .put (f .read (), file_name , binary = True )
136155 import importlib .util
0 commit comments