Skip to content

Commit 616e223

Browse files
committed
Make loader buildable on windows
Signed-off-by: Gregory Shimansky <[email protected]>
1 parent 14c6b13 commit 616e223

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

python/triton/runtime/build.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def quiet():
2727

2828
def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
2929
if cc in ["cl", "clang-cl"]:
30-
cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "-std:c++20"]
30+
cc_cmd = [cc, src, "/nologo", "/O2", "/LD"]
3131
cc_cmd += [f"/I{dir}" for dir in include_dirs]
3232
cc_cmd += [f"/Fo{os.path.join(os.path.dirname(out), 'main.obj')}"]
3333
cc_cmd += ["/link"]
@@ -37,14 +37,14 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
3737
cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs]
3838
cc_cmd += [f'{lib}.lib' for lib in libraries]
3939
else:
40-
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC"]
40+
cc_cmd = [cc, src, "-O3", "-shared"]
41+
if os.name != "nt":
42+
cc_cmd += ["fPIC"]
4143
cc_cmd += [f'-l{lib}' for lib in libraries]
4244
cc_cmd += [f"-L{dir}" for dir in library_dirs]
4345
cc_cmd += [f"-I{dir}" for dir in include_dirs]
4446
cc_cmd += ["-o", out]
4547

46-
if os.name == "nt": cc_cmd.pop(cc_cmd.index("-fPIC"))
47-
4848
return cc_cmd
4949

5050

@@ -75,6 +75,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
7575
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
7676
custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH'))
7777
include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
78+
extra_compiler_flags = []
7879

7980
if is_xpu():
8081
icpx = None
@@ -83,22 +84,25 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
8384
clangpp = shutil.which("clang++")
8485
gxx = shutil.which("g++")
8586
icpx = shutil.which("icpx")
86-
cxx = icpx or clangpp or gxx
87+
cxx = icpx if os.name == "nt" else icpx or clangpp or gxx
8788
if cxx is None:
8889
raise RuntimeError("Failed to find C++ compiler. Please specify via CXX environment variable.")
90+
cc = cxx
8991
import numpy as np
9092
numpy_include_dir = np.get_include()
9193
include_dirs = include_dirs + [numpy_include_dir]
92-
cc_cmd = [cxx]
9394
if icpx is not None:
94-
cc_cmd += ["-fsycl"]
95+
extra_compiler_flags += ["-fsycl"]
9596
else:
96-
cc_cmd += ["--std=c++17"]
97+
extra_compiler_flags += ["--std=c++17"]
98+
if os.name == "nt":
99+
library_dirs += [os.path.join(sysconfig.get_paths(scheme=scheme)["stdlib"], "..", "libs")]
97100
else:
98101
cc_cmd = [cc]
99102

100103
# for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
101104
cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
105+
cc_cmd += extra_compiler_flags
102106

103107
if os.getenv("VERBOSE"):
104108
print(" ".join(cc_cmd))

third_party/intel/backend/driver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], list[str]]:
3333
"or provide `ONEAPI_ROOT` environment "
3434
"or install `intel-sycl-rt>=2025.0.0` wheel")
3535

36-
if shutil.which("icpx"):
36+
if shutil.which("icpx") and os.name != "nt":
3737
# only `icpx` compiler knows where sycl runtime binaries and header files are
3838
return include_dir, library_dir
3939

@@ -71,14 +71,18 @@ class CompilationHelper:
7171
def __init__(self):
7272
self._library_dir = None
7373
self._include_dir = None
74-
self.libraries = ['ze_loader', 'sycl']
74+
self.libraries = ['ze_loader']
75+
if os.name != "nt":
76+
self.libraries += ["sycl"]
7577

7678
@cached_property
7779
def _compute_compilation_options_lazy(self):
7880
ze_root = os.getenv("ZE_PATH", default="/usr/local")
7981
include_dir = [os.path.join(ze_root, "include")]
8082

8183
include_dir, library_dir = find_sycl(include_dir)
84+
if os.name == "nt":
85+
library_dir += [os.path.join(ze_root, "lib")]
8286

8387
dirname = os.path.dirname(os.path.realpath(__file__))
8488
include_dir += [os.path.join(dirname, "include")]

0 commit comments

Comments
 (0)