Skip to content

Commit 92fc8da

Browse files
Merge commit '663e04e8e3ebed7ee3230a1a7320142689795106'
2 parents 563c2c1 + 663e04e commit 92fc8da

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

python/triton/runtime/build.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
4747

4848

4949
def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
50-
extra_compile_args: list[str] = []) -> str:
50+
ccflags: list[str] = []) -> str:
5151
if impl := knobs.build.impl:
52-
return impl(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compile_args)
52+
return impl(name, src, srcdir, library_dirs, include_dirs, libraries, ccflags)
5353
suffix = sysconfig.get_config_var('EXT_SUFFIX')
5454
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
5555
# try to avoid setuptools if possible
@@ -92,12 +92,12 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di
9292
numpy_include_dir = np.get_include()
9393
include_dirs = include_dirs + [numpy_include_dir]
9494
if cxx is icpx:
95-
extra_compile_args += ["-fsycl"]
95+
ccflags += ["-fsycl"]
9696
else:
9797
if os.name != "nt":
98-
extra_compile_args += ["--std=c++17"]
98+
ccflags += ["--std=c++17"]
9999
if os.environ.get("TRITON_SUPPRESS_GCC_HOST_CODE_DEPRECATION_WARNINGS", "1") == "1":
100-
extra_compile_args += ["-Wno-deprecated-declarations"]
100+
ccflags += ["-Wno-deprecated-declarations"]
101101
if os.name == "nt":
102102
library_dirs = library_dirs + [
103103
os.path.abspath(os.path.join(sysconfig.get_paths(scheme=scheme)["stdlib"], "..", "libs"))
@@ -107,7 +107,7 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di
107107

108108
# for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
109109
cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
110-
cc_cmd += extra_compile_args
110+
cc_cmd += ccflags
111111

112112
if os.getenv("VERBOSE"):
113113
print(" ".join(cc_cmd))
@@ -132,7 +132,8 @@ def _load_module_from_path(name: str, path: str) -> ModuleType:
132132

133133

134134
def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
135-
include_dirs: list[str] | None = None, libraries: list[str] | None = None) -> ModuleType:
135+
include_dirs: list[str] | None = None, libraries: list[str] | None = None,
136+
ccflags: list[str] | None = None) -> ModuleType:
136137
key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
137138
cache = get_cache_manager(key)
138139
suffix = sysconfig.get_config_var("EXT_SUFFIX")
@@ -149,7 +150,7 @@ def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None
149150
src_path = os.path.join(tmpdir, name + ".c")
150151
with open(src_path, "w") as f:
151152
f.write(src)
152-
so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [])
153+
so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
153154
with open(so, "rb") as f:
154155
cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
155156

0 commit comments

Comments
 (0)