1010from triton .runtime .build import _build , quiet
1111
1212import torch
13- import intel_extension_for_pytorch
13+
14+ from .benchmark_testing import USE_IPEX_OPTION
1415
1516_dirname = os .getenv ("ZE_PATH" , default = "/usr/local" )
1617
1718include_dir = [
1819 os .path .join (_dirname , "include" ),
1920 os .path .join (torch .utils .cmake_prefix_path , "../../include" ),
20- os .path .join (torch .utils .cmake_prefix_path , "../../include/torch/csrc/api/include" ),
21- os .path .join (intel_extension_for_pytorch .cmake_prefix_path , "../../include" )
21+ os .path .join (torch .utils .cmake_prefix_path , "../../include/torch/csrc/api/include" )
2222]
2323
2424oneapi_root = os .getenv ("ONEAPI_ROOT" )
2828 os .path .join (oneapi_root , "compiler/latest/include/sycl" )
2929 ]
3030
31- library_dir = [
32- os .path .join (_dirname , "lib" ),
33- os .path .join (torch .utils .cmake_prefix_path , "../../lib" ),
34- os .path .join (intel_extension_for_pytorch .cmake_prefix_path , "../../lib" )
35- ]
36- libraries = ["ze_loader" , "sycl" , "torch" , "intel-ext-pt-gpu" ]
31+ library_dir = [os .path .join (_dirname , "lib" ), os .path .join (torch .utils .cmake_prefix_path , "../../lib" )]
32+ libraries = ["ze_loader" , "sycl" , "torch" ]
33+
34+ if USE_IPEX_OPTION :
35+ import intel_extension_for_pytorch
36+
37+ include_dir .append (os .path .join (intel_extension_for_pytorch .cmake_prefix_path , "../../include" ))
38+ library_dir .append (os .path .join (intel_extension_for_pytorch .cmake_prefix_path , "../../lib" ))
39+ libraries .append ("intel-ext-pt-gpu" )
3740
3841
3942def compile_module_from_src (src , name ):
@@ -141,6 +144,14 @@ def format_of(ty):
141144 fmt = "iiiOOOOOO" + args_format
142145 args_list = ", " + ", " .join (f"&_arg{ i } " for i , ty in signature .items ()) if len (signature ) > 0 else ""
143146
147+ record_function_header = "#include <ATen/record_function.h>"
148+ ipex_header = ""
149+ xpu_profiler_record = ""
150+ if USE_IPEX_OPTION :
151+ record_function_header = "#include <torch/extension.h>"
152+ ipex_header = "#include <ipex.h>"
153+ xpu_profiler_record = "xpu::profiler_record(kernel_name, event);"
154+
144155 # generate glue code
145156 src = f"""
146157 #include <cstddef>
@@ -149,8 +160,8 @@ def format_of(ty):
149160 #include <iomanip>
150161 #include <level_zero/ze_api.h>
151162 #include <sycl/sycl.hpp>
152- #include <torch/extension.h>
153- #include <ipex.h>
163+ { record_function_header }
164+ { ipex_header }
154165
155166 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
156167 #include <Python.h>
@@ -291,7 +302,7 @@ def format_of(ty):
291302 }}
292303 }};
293304 auto event = stream.submit(cgf);
294- xpu::profiler_record(kernel_name, event);
305+ { xpu_profiler_record }
295306 }}
296307// end sycl
297308 static PyObject* launch(PyObject* self, PyObject* args) {{
0 commit comments