Skip to content

Commit 663e04e

Browse files
authored
[python] Allow backend-provided runtime host compiler flags (#7659)
Modifies the `compile_module_from_src` function (and `_build`) to allow third party plugins to provide arbitrary cc compiler flags when compiling modules from source at runtime. This is useful for backends like Intel (`-fsycl`) or for the CPU project (`-fopenmp`) and also for supporting dynamic lookup for the Python library when running on macOS.
1 parent b3b9931 commit 663e04e

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

python/triton/runtime/build.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from .. import knobs
1717

1818

19-
def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str],
20-
libraries: list[str]) -> str:
19+
def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
20+
ccflags: list[str]) -> str:
2121
if impl := knobs.build.impl:
2222
return impl(name, src, srcdir, library_dirs, include_dirs, libraries)
2323
suffix = sysconfig.get_config_var('EXT_SUFFIX')
@@ -48,6 +48,7 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di
4848
cc_cmd += [f'-l{lib}' for lib in libraries]
4949
cc_cmd += [f"-L{dir}" for dir in library_dirs]
5050
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
51+
cc_cmd.extend(ccflags)
5152
subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
5253
return so
5354

@@ -68,7 +69,8 @@ def _load_module_from_path(name: str, path: str) -> ModuleType:
6869

6970

7071
def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
71-
include_dirs: list[str] | None = None, libraries: list[str] | None = None) -> ModuleType:
72+
include_dirs: list[str] | None = None, libraries: list[str] | None = None,
73+
ccflags: list[str] | None = None) -> ModuleType:
7274
key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
7375
cache = get_cache_manager(key)
7476
suffix = sysconfig.get_config_var("EXT_SUFFIX")
@@ -85,7 +87,7 @@ def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None
8587
src_path = os.path.join(tmpdir, name + ".c")
8688
with open(src_path, "w") as f:
8789
f.write(src)
88-
so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [])
90+
so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
8991
with open(so, "rb") as f:
9092
cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
9193

0 commit comments

Comments
 (0)