Skip to content

Commit e4e0905

Browse files
authored
Reuse compile_module_from_src func in benchmark_driver.py (#3051)
Part of #2540 --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 7d03355 commit e4e0905

File tree

2 files changed

+28
-49
lines changed

2 files changed

+28
-49
lines changed

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,19 @@
11
import os
2-
import hashlib
3-
import importlib.util
4-
import tempfile
52
from pathlib import Path
63

74
from triton.backends.compiler import GPUTarget
85
from triton.backends.driver import DriverBase
9-
from triton.runtime.cache import get_cache_manager
10-
from triton.runtime.build import _build, quiet
116
from triton._utils import parse_list_string
7+
from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER
128

139
import torch
1410

15-
_dirname = os.getenv("ZE_PATH", default="/usr/local")
16-
17-
include_dir = [
18-
os.path.join(_dirname, "include"),
19-
os.path.join(torch.utils.cmake_prefix_path, "../../include"),
20-
os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include")
21-
]
22-
23-
oneapi_root = os.getenv("ONEAPI_ROOT")
24-
if oneapi_root:
25-
include_dir += [
26-
os.path.join(oneapi_root, "compiler/latest/include"),
27-
os.path.join(oneapi_root, "compiler/latest/include/sycl")
28-
]
29-
30-
library_dir = [os.path.join(_dirname, "lib"), os.path.join(torch.utils.cmake_prefix_path, "../../lib")]
31-
libraries = ["ze_loader", "sycl", "torch"]
32-
33-
34-
def compile_module_from_src(src, name):
35-
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
36-
cache = get_cache_manager(key)
37-
cache_path = cache.get_file(f"{name}.so")
38-
if cache_path is None:
39-
with tempfile.TemporaryDirectory() as tmpdir:
40-
src_path = os.path.join(tmpdir, "main.cpp")
41-
with open(src_path, "w", encoding="utf-8") as f:
42-
f.write(src)
43-
with quiet():
44-
so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries)
45-
with open(so, "rb") as f:
46-
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
47-
spec = importlib.util.spec_from_file_location(name, cache_path)
48-
mod = importlib.util.module_from_spec(spec)
49-
spec.loader.exec_module(mod)
50-
return mod
51-
52-
5311
# ------------------------
5412
# Utils
5513
# ------------------------
5614

15+
COMPILATION_HELPER.inject_pytorch_dep()
16+
5717

5818
class XPUUtils:
5919

third_party/intel/backend/driver.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
6868
class CompilationHelper:
6969
_library_dir: list[str]
7070
_include_dir: list[str]
71+
libraries: list[str]
72+
73+
# for benchmarks
74+
_build_with_pytorch_dep: bool = False
7175

7276
def __init__(self):
7377
self._library_dir = None
@@ -77,6 +81,12 @@ def __init__(self):
7781
if os.name != "nt":
7882
self.libraries += ["sycl"]
7983

84+
def inject_pytorch_dep(self):
85+
# must be called before any cached properties (if pytorch is needed)
86+
if self._build_with_pytorch_dep is False:
87+
self._build_with_pytorch_dep = True
88+
self.libraries += ['torch']
89+
8090
@cached_property
8191
def _compute_compilation_options_lazy(self):
8292
ze_root = os.getenv("ZE_PATH", default="/usr/local")
@@ -91,9 +101,18 @@ def _compute_compilation_options_lazy(self):
91101

92102
dirname = os.path.dirname(os.path.realpath(__file__))
93103
include_dir += [os.path.join(dirname, "include")]
94-
# TODO: do we need this?
95104
library_dir += [os.path.join(dirname, "lib")]
96105

106+
if self._build_with_pytorch_dep:
107+
import torch
108+
109+
torch_path = torch.utils.cmake_prefix_path
110+
include_dir += [
111+
os.path.join(torch_path, "../../include"),
112+
os.path.join(torch_path, "../../include/torch/csrc/api/include"),
113+
]
114+
library_dir += [os.path.join(torch_path, "../../lib")]
115+
97116
self._library_dir = library_dir
98117
self._include_dir = include_dir
99118

@@ -113,7 +132,7 @@ def libsycl_dir(self) -> Optional[str]:
113132
return self._libsycl_dir
114133

115134

116-
compilation_helper = CompilationHelper()
135+
COMPILATION_HELPER = CompilationHelper()
117136

118137

119138
def compile_module_from_src(src, name):
@@ -127,10 +146,10 @@ def compile_module_from_src(src, name):
127146
with open(src_path, "w") as f:
128147
f.write(src)
129148
extra_compiler_args = []
130-
if compilation_helper.libsycl_dir:
131-
extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir]
132-
so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir,
133-
compilation_helper.libraries, extra_compile_args=extra_compiler_args)
149+
if COMPILATION_HELPER.libsycl_dir:
150+
extra_compiler_args += ['-Wl,-rpath,' + COMPILATION_HELPER.libsycl_dir]
151+
so = _build(name, src_path, tmpdir, COMPILATION_HELPER.library_dir, COMPILATION_HELPER.include_dir,
152+
COMPILATION_HELPER.libraries, extra_compile_args=extra_compiler_args)
134153
with open(so, "rb") as f:
135154
cache_path = cache.put(f.read(), file_name, binary=True)
136155
import importlib.util

0 commit comments

Comments
 (0)