|
1 | 1 | import os |
2 | | -import hashlib |
3 | | -import importlib.util |
4 | | -import tempfile |
5 | 2 | from pathlib import Path |
6 | 3 |
|
7 | 4 | from triton.backends.compiler import GPUTarget |
8 | 5 | from triton.backends.driver import DriverBase |
9 | 6 | from triton.runtime.cache import get_cache_manager |
10 | 7 | from triton.runtime.build import _build, quiet |
11 | 8 | from triton._utils import parse_list_string |
| 9 | +from triton.backends.intel.driver import compile_module_from_src |
12 | 10 |
|
13 | 11 | import torch |
14 | 12 |
|
15 | | -_dirname = os.getenv("ZE_PATH", default="/usr/local") |
16 | | - |
17 | | -include_dir = [ |
18 | | - os.path.join(_dirname, "include"), |
19 | | - os.path.join(torch.utils.cmake_prefix_path, "../../include"), |
20 | | - os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include") |
21 | | -] |
22 | | - |
23 | | -oneapi_root = os.getenv("ONEAPI_ROOT") |
24 | | -if oneapi_root: |
25 | | - include_dir += [ |
26 | | - os.path.join(oneapi_root, "compiler/latest/include"), |
27 | | - os.path.join(oneapi_root, "compiler/latest/include/sycl") |
28 | | - ] |
29 | | - |
30 | | -library_dir = [os.path.join(_dirname, "lib"), os.path.join(torch.utils.cmake_prefix_path, "../../lib")] |
31 | | -libraries = ["ze_loader", "sycl", "torch"] |
32 | | - |
33 | | - |
34 | | -def compile_module_from_src(src, name): |
35 | | - key = hashlib.sha256(src.encode("utf-8")).hexdigest() |
36 | | - cache = get_cache_manager(key) |
37 | | - cache_path = cache.get_file(f"{name}.so") |
38 | | - if cache_path is None: |
39 | | - with tempfile.TemporaryDirectory() as tmpdir: |
40 | | - src_path = os.path.join(tmpdir, "main.cpp") |
41 | | - with open(src_path, "w", encoding="utf-8") as f: |
42 | | - f.write(src) |
43 | | - with quiet(): |
44 | | - so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) |
45 | | - with open(so, "rb") as f: |
46 | | - cache_path = cache.put(f.read(), f"{name}.so", binary=True) |
47 | | - spec = importlib.util.spec_from_file_location(name, cache_path) |
48 | | - mod = importlib.util.module_from_spec(spec) |
49 | | - spec.loader.exec_module(mod) |
50 | | - return mod |
51 | | - |
52 | | - |
53 | 13 | # ------------------------ |
54 | 14 | # Utils |
55 | 15 | # ------------------------ |
|
0 commit comments