Skip to content

Commit 524d51c

Browse files
committed
updaates
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent df6320e commit 524d51c

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33

44
from triton.backends.compiler import GPUTarget
55
from triton.backends.driver import DriverBase
6-
from triton.backends.intel.driver import compile_module_from_src
6+
from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER
77

88
import torch
99

1010
# ------------------------
1111
# Utils
1212
# ------------------------
1313

14+
COMPILATION_HELPER.inject_pytorch_dep()
15+
1416

1517
class XPUUtils:
1618

third_party/intel/backend/driver.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
6767
class CompilationHelper:
6868
_library_dir: list[str]
6969
_include_dir: list[str]
70+
libraries: list[str]
71+
72+
# for benchmarks
73+
_build_with_pytorch_dep: bool = False
7074

7175
def __init__(self):
7276
self._library_dir = None
@@ -76,6 +80,12 @@ def __init__(self):
7680
if os.name != "nt":
7781
self.libraries += ["sycl"]
7882

83+
def inject_pytorch_dep(self):
84+
# must be called before any cached properties (if pytorch is needed)
85+
if self._build_with_pytorch_dep is False:
86+
self._build_with_pytorch_dep = True
87+
self.libraries += ['torch']
88+
7989
@cached_property
8090
def _compute_compilation_options_lazy(self):
8191
ze_root = os.getenv("ZE_PATH", default="/usr/local")
@@ -90,9 +100,18 @@ def _compute_compilation_options_lazy(self):
90100

91101
dirname = os.path.dirname(os.path.realpath(__file__))
92102
include_dir += [os.path.join(dirname, "include")]
93-
# TODO: do we need this?
94103
library_dir += [os.path.join(dirname, "lib")]
95104

105+
if self._build_with_pytorch_dep:
106+
import torch
107+
108+
torch_path = torch.utils.cmake_prefix_path
109+
include_dir += [
110+
os.path.join(torch_path, "../../include"),
111+
os.path.join(torch_path, "../../include/torch/csrc/api/include"),
112+
]
113+
library_dir += [os.path.join(torch_path, "../../lib")]
114+
96115
self._library_dir = library_dir
97116
self._include_dir = include_dir
98117

@@ -112,7 +131,7 @@ def libsycl_dir(self) -> Optional[str]:
112131
return self._libsycl_dir
113132

114133

115-
compilation_helper = CompilationHelper()
134+
COMPILATION_HELPER = CompilationHelper()
116135

117136

118137
def compile_module_from_src(src, name):
@@ -126,10 +145,10 @@ def compile_module_from_src(src, name):
126145
with open(src_path, "w") as f:
127146
f.write(src)
128147
extra_compiler_args = []
129-
if compilation_helper.libsycl_dir:
130-
extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir]
131-
so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir,
132-
compilation_helper.libraries, extra_compile_args=extra_compiler_args)
148+
if COMPILATION_HELPER.libsycl_dir:
149+
extra_compiler_args += ['-Wl,-rpath,' + COMPILATION_HELPER.libsycl_dir]
150+
so = _build(name, src_path, tmpdir, COMPILATION_HELPER.library_dir, COMPILATION_HELPER.include_dir,
151+
COMPILATION_HELPER.libraries, extra_compile_args=extra_compiler_args)
133152
with open(so, "rb") as f:
134153
cache_path = cache.put(f.read(), file_name, binary=True)
135154
import importlib.util

0 commit comments

Comments
 (0)