Skip to content

Commit df6320e

Browse files
committed
Reuse 'compile_module_from_src' in 'benchmark_driver.py'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 97ba60e commit df6320e

File tree

1 file changed

+1
-43
lines changed

1 file changed

+1
-43
lines changed

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,12 @@
11
import os
2-
import hashlib
3-
import importlib.util
4-
import tempfile
52
from pathlib import Path
63

74
from triton.backends.compiler import GPUTarget
85
from triton.backends.driver import DriverBase
9-
from triton.runtime.cache import get_cache_manager
10-
from triton.runtime.build import _build, quiet
6+
from triton.backends.intel.driver import compile_module_from_src
117

128
import torch
139

14-
_dirname = os.getenv("ZE_PATH", default="/usr/local")
15-
16-
include_dir = [
17-
os.path.join(_dirname, "include"),
18-
os.path.join(torch.utils.cmake_prefix_path, "../../include"),
19-
os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include")
20-
]
21-
22-
oneapi_root = os.getenv("ONEAPI_ROOT")
23-
if oneapi_root:
24-
include_dir += [
25-
os.path.join(oneapi_root, "compiler/latest/include"),
26-
os.path.join(oneapi_root, "compiler/latest/include/sycl")
27-
]
28-
29-
library_dir = [os.path.join(_dirname, "lib"), os.path.join(torch.utils.cmake_prefix_path, "../../lib")]
30-
libraries = ["ze_loader", "sycl", "torch"]
31-
32-
33-
def compile_module_from_src(src, name):
34-
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
35-
cache = get_cache_manager(key)
36-
cache_path = cache.get_file(f"{name}.so")
37-
if cache_path is None:
38-
with tempfile.TemporaryDirectory() as tmpdir:
39-
src_path = os.path.join(tmpdir, "main.cpp")
40-
with open(src_path, "w", encoding="utf-8") as f:
41-
f.write(src)
42-
with quiet():
43-
so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries)
44-
with open(so, "rb") as f:
45-
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
46-
spec = importlib.util.spec_from_file_location(name, cache_path)
47-
mod = importlib.util.module_from_spec(spec)
48-
spec.loader.exec_module(mod)
49-
return mod
50-
51-
5210
# ------------------------
5311
# Utils
5412
# ------------------------

0 commit comments

Comments
 (0)