Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/triton/runtime/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def quiet():
sys.stdout, sys.stderr = old_stdout, old_stderr


def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compile_args=[]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change in common file should be upstreamed. Can you open an issue so you can do that separately please ?

suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
Expand Down Expand Up @@ -74,15 +74,14 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
cc_cmd += extra_compile_args

if os.getenv("VERBOSE"):
print(" ".join(cc_cmd))

ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
# fallback on setuptools
extra_compile_args = []
# extra arguments
extra_link_args = []
# create extension module
Expand Down
32 changes: 23 additions & 9 deletions third_party/intel/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tempfile
from pathlib import Path
from functools import cached_property
from typing import Optional

from triton.runtime.build import _build
from triton.runtime.cache import get_cache_manager
Expand All @@ -14,36 +15,35 @@
from packaging.specifiers import SpecifierSet


def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
"""
Looks for the sycl library in known places.

Arguments:
include_dir: list of include directories to pass to compiler.

Returns:
enriched include_dir and library_dir.
enriched include_dir and libsycl.so location.

Raises:
AssertionError: if library was not found.
"""
library_dir = []
include_dir = include_dir.copy()
assertion_message = ("sycl headers not found, please install `icpx` compiler, "
"or provide `ONEAPI_ROOT` environment "
"or install `intel-sycl-rt>=2025.0.0` wheel")

if shutil.which("icpx"):
# only `icpx` compiler knows where sycl runtime binaries and header files are
return include_dir, library_dir
return include_dir, None

oneapi_root = os.getenv("ONEAPI_ROOT")
if oneapi_root:
include_dir += [
os.path.join(oneapi_root, "compiler/latest/include"),
os.path.join(oneapi_root, "compiler/latest/include/sycl")
]
return include_dir, library_dir
return include_dir, None

try:
sycl_rt = importlib.metadata.metadata("intel-sycl-rt")
Expand All @@ -53,15 +53,16 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
if Version(sycl_rt.get("version", "0.0.0")) in SpecifierSet("<2025.0.0a1"):
raise AssertionError(assertion_message)

sycl_dir = None
for f in importlib.metadata.files("intel-sycl-rt"):
# sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders
# being add: include and include/sycl.
if f.name == "sycl.hpp":
include_dir += [f.locate().parent.parent.resolve().as_posix()]
if f.name == "libsycl.so":
library_dir += [f.locate().parent.resolve().as_posix()]
sycl_dir = f.locate().parent.resolve().as_posix()

return include_dir, library_dir
return include_dir, sycl_dir


class CompilationHelper:
Expand All @@ -71,17 +72,22 @@ class CompilationHelper:
def __init__(self):
self._library_dir = None
self._include_dir = None
self._libsycl_dir = None
self.libraries = ['ze_loader', 'sycl']

@cached_property
def _compute_compilation_options_lazy(self):
ze_root = os.getenv("ZE_PATH", default="/usr/local")
include_dir = [os.path.join(ze_root, "include")]

include_dir, library_dir = find_sycl(include_dir)
library_dir = []
include_dir, self._libsycl_dir = find_sycl(include_dir)
if self._libsycl_dir:
library_dir += [self._libsycl_dir]

dirname = os.path.dirname(os.path.realpath(__file__))
include_dir += [os.path.join(dirname, "include")]
# TODO: do we need this?
library_dir += [os.path.join(dirname, "lib")]

self._library_dir = library_dir
Expand All @@ -97,6 +103,11 @@ def include_dir(self) -> list[str]:
self._compute_compilation_options_lazy
return self._include_dir

@cached_property
def libsycl_dir(self) -> Optional[str]:
self._compute_compilation_options_lazy
return self._libsycl_dir


compilation_helper = CompilationHelper()

Expand All @@ -110,8 +121,11 @@ def compile_module_from_src(src, name):
src_path = os.path.join(tmpdir, "main.cpp")
with open(src_path, "w") as f:
f.write(src)
extra_compiler_args = []
if compilation_helper.libsycl_dir:
extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir]
so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir,
compilation_helper.libraries)
compilation_helper.libraries, extra_compile_args=extra_compiler_args)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
import importlib.util
Expand Down