Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 3 additions & 43 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,19 @@
import os
import hashlib
import importlib.util
import tempfile
from pathlib import Path

from triton.backends.compiler import GPUTarget
from triton.backends.driver import DriverBase
from triton.runtime.cache import get_cache_manager
from triton.runtime.build import _build, quiet
from triton._utils import parse_list_string
from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER

import torch

_dirname = os.getenv("ZE_PATH", default="/usr/local")

include_dir = [
os.path.join(_dirname, "include"),
os.path.join(torch.utils.cmake_prefix_path, "../../include"),
os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include")
]

oneapi_root = os.getenv("ONEAPI_ROOT")
if oneapi_root:
include_dir += [
os.path.join(oneapi_root, "compiler/latest/include"),
os.path.join(oneapi_root, "compiler/latest/include/sycl")
]

library_dir = [os.path.join(_dirname, "lib"), os.path.join(torch.utils.cmake_prefix_path, "../../lib")]
libraries = ["ze_loader", "sycl", "torch"]


def compile_module_from_src(src, name):
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
cache_path = cache.get_file(f"{name}.so")
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "main.cpp")
with open(src_path, "w", encoding="utf-8") as f:
f.write(src)
with quiet():
so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
spec = importlib.util.spec_from_file_location(name, cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod


# ------------------------
# Utils
# ------------------------

COMPILATION_HELPER.inject_pytorch_dep()


class XPUUtils:

Expand Down
31 changes: 25 additions & 6 deletions third_party/intel/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
class CompilationHelper:
_library_dir: list[str]
_include_dir: list[str]
libraries: list[str]

# for benchmarks
_build_with_pytorch_dep: bool = False

def __init__(self):
self._library_dir = None
Expand All @@ -77,6 +81,12 @@ def __init__(self):
if os.name != "nt":
self.libraries += ["sycl"]

def inject_pytorch_dep(self):
# must be called before any cached properties (if pytorch is needed)
if self._build_with_pytorch_dep is False:
self._build_with_pytorch_dep = True
self.libraries += ['torch']

@cached_property
def _compute_compilation_options_lazy(self):
ze_root = os.getenv("ZE_PATH", default="/usr/local")
Expand All @@ -91,9 +101,18 @@ def _compute_compilation_options_lazy(self):

dirname = os.path.dirname(os.path.realpath(__file__))
include_dir += [os.path.join(dirname, "include")]
# TODO: do we need this?
library_dir += [os.path.join(dirname, "lib")]

if self._build_with_pytorch_dep:
import torch

torch_path = torch.utils.cmake_prefix_path
include_dir += [
os.path.join(torch_path, "../../include"),
os.path.join(torch_path, "../../include/torch/csrc/api/include"),
]
library_dir += [os.path.join(torch_path, "../../lib")]

self._library_dir = library_dir
self._include_dir = include_dir

Expand All @@ -113,7 +132,7 @@ def libsycl_dir(self) -> Optional[str]:
return self._libsycl_dir


compilation_helper = CompilationHelper()
COMPILATION_HELPER = CompilationHelper()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To mark that the variable is global



def compile_module_from_src(src, name):
Expand All @@ -127,10 +146,10 @@ def compile_module_from_src(src, name):
with open(src_path, "w") as f:
f.write(src)
extra_compiler_args = []
if compilation_helper.libsycl_dir:
extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir]
so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir,
compilation_helper.libraries, extra_compile_args=extra_compiler_args)
if COMPILATION_HELPER.libsycl_dir:
extra_compiler_args += ['-Wl,-rpath,' + COMPILATION_HELPER.libsycl_dir]
so = _build(name, src_path, tmpdir, COMPILATION_HELPER.library_dir, COMPILATION_HELPER.include_dir,
COMPILATION_HELPER.libraries, extra_compile_args=extra_compiler_args)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), file_name, binary=True)
import importlib.util
Expand Down
Loading