Skip to content

Commit 237a69b

Browse files
committed
Reuse 'compile_module_from_src' in 'benchmark_driver.py'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 3e1165f commit 237a69b

File tree

1 file changed

+1
-41
lines changed

1 file changed

+1
-41
lines changed

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,15 @@
11
import os
2-
import hashlib
3-
import importlib.util
4-
import tempfile
52
from pathlib import Path
63

74
from triton.backends.compiler import GPUTarget
85
from triton.backends.driver import DriverBase
96
from triton.runtime.cache import get_cache_manager
107
from triton.runtime.build import _build, quiet
118
from triton._utils import parse_list_string
9+
from triton.backends.intel.driver import compile_module_from_src
1210

1311
import torch
1412

15-
_dirname = os.getenv("ZE_PATH", default="/usr/local")
16-
17-
include_dir = [
18-
os.path.join(_dirname, "include"),
19-
os.path.join(torch.utils.cmake_prefix_path, "../../include"),
20-
os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include")
21-
]
22-
23-
oneapi_root = os.getenv("ONEAPI_ROOT")
24-
if oneapi_root:
25-
include_dir += [
26-
os.path.join(oneapi_root, "compiler/latest/include"),
27-
os.path.join(oneapi_root, "compiler/latest/include/sycl")
28-
]
29-
30-
library_dir = [os.path.join(_dirname, "lib"), os.path.join(torch.utils.cmake_prefix_path, "../../lib")]
31-
libraries = ["ze_loader", "sycl", "torch"]
32-
33-
34-
def compile_module_from_src(src, name):
35-
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
36-
cache = get_cache_manager(key)
37-
cache_path = cache.get_file(f"{name}.so")
38-
if cache_path is None:
39-
with tempfile.TemporaryDirectory() as tmpdir:
40-
src_path = os.path.join(tmpdir, "main.cpp")
41-
with open(src_path, "w", encoding="utf-8") as f:
42-
f.write(src)
43-
with quiet():
44-
so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries)
45-
with open(so, "rb") as f:
46-
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
47-
spec = importlib.util.spec_from_file_location(name, cache_path)
48-
mod = importlib.util.module_from_spec(spec)
49-
spec.loader.exec_module(mod)
50-
return mod
51-
52-
5313
# ------------------------
5414
# Utils
5515
# ------------------------

0 commit comments

Comments
 (0)