Skip to content

Commit 7edd254

Browse files
committed
updaates
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 237a69b commit 7edd254

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
from triton.runtime.cache import get_cache_manager
77
from triton.runtime.build import _build, quiet
88
from triton._utils import parse_list_string
9-
from triton.backends.intel.driver import compile_module_from_src
9+
from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER
1010

1111
import torch
1212

1313
# ------------------------
1414
# Utils
1515
# ------------------------
1616

17+
COMPILATION_HELPER.inject_pytorch_dep()
18+
1719

1820
class XPUUtils:
1921

third_party/intel/backend/driver.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
6868
class CompilationHelper:
6969
_library_dir: list[str]
7070
_include_dir: list[str]
71+
libraries: list[str]
72+
73+
# for benchmarks
74+
_build_with_pytorch_dep: bool = False
7175

7276
def __init__(self):
7377
self._library_dir = None
@@ -77,6 +81,12 @@ def __init__(self):
7781
if os.name != "nt":
7882
self.libraries += ["sycl"]
7983

84+
def inject_pytorch_dep(self):
85+
# must be called before any cached properties (if pytorch is needed)
86+
if self._build_with_pytorch_dep is False:
87+
self._build_with_pytorch_dep = True
88+
self.libraries += ['torch']
89+
8090
@cached_property
8191
def _compute_compilation_options_lazy(self):
8292
ze_root = os.getenv("ZE_PATH", default="/usr/local")
@@ -91,9 +101,18 @@ def _compute_compilation_options_lazy(self):
91101

92102
dirname = os.path.dirname(os.path.realpath(__file__))
93103
include_dir += [os.path.join(dirname, "include")]
94-
# TODO: do we need this?
95104
library_dir += [os.path.join(dirname, "lib")]
96105

106+
if self._build_with_pytorch_dep:
107+
import torch
108+
109+
torch_path = torch.utils.cmake_prefix_path
110+
include_dir += [
111+
os.path.join(torch_path, "../../include"),
112+
os.path.join(torch_path, "../../include/torch/csrc/api/include"),
113+
]
114+
library_dir += [os.path.join(torch_path, "../../lib")]
115+
97116
self._library_dir = library_dir
98117
self._include_dir = include_dir
99118

@@ -113,7 +132,7 @@ def libsycl_dir(self) -> Optional[str]:
113132
return self._libsycl_dir
114133

115134

116-
compilation_helper = CompilationHelper()
135+
COMPILATION_HELPER = CompilationHelper()
117136

118137

119138
def compile_module_from_src(src, name):
@@ -127,10 +146,10 @@ def compile_module_from_src(src, name):
127146
with open(src_path, "w") as f:
128147
f.write(src)
129148
extra_compiler_args = []
130-
if compilation_helper.libsycl_dir:
131-
extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir]
132-
so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir,
133-
compilation_helper.libraries, extra_compile_args=extra_compiler_args)
149+
if COMPILATION_HELPER.libsycl_dir:
150+
extra_compiler_args += ['-Wl,-rpath,' + COMPILATION_HELPER.libsycl_dir]
151+
so = _build(name, src_path, tmpdir, COMPILATION_HELPER.library_dir, COMPILATION_HELPER.include_dir,
152+
COMPILATION_HELPER.libraries, extra_compile_args=extra_compiler_args)
134153
with open(so, "rb") as f:
135154
cache_path = cache.put(f.read(), file_name, binary=True)
136155
import importlib.util

0 commit comments

Comments
 (0)