Skip to content

Commit 1f73581

Browse files
authored
Enable msvc to compile kernel launchers (#3185)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 1c766f0 commit 1f73581

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

python/triton/runtime/build.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ def is_xpu():
1010

1111

1212
def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
13-
if cc in ["cl", "clang-cl"]:
14-
cc_cmd = [cc, src, "/nologo", "/O2", "/LD"]
13+
if "cl.EXE" in cc or "clang-cl" in cc:
14+
cc_cmd = [cc, "/Zc:__cplusplus", src, "/nologo", "/O2", "/LD"]
1515
cc_cmd += [f"/I{dir}" for dir in include_dirs]
1616
cc_cmd += [f"/Fo{os.path.join(os.path.dirname(out), 'main.obj')}"]
1717
cc_cmd += ["/link"]
@@ -66,19 +66,22 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
6666
clangpp = shutil.which("clang++")
6767
gxx = shutil.which("g++")
6868
icpx = shutil.which("icpx")
69-
cxx = icpx if os.name == "nt" else icpx or clangpp or gxx
69+
cl = shutil.which("cl")
70+
cxx = icpx or cl if os.name == "nt" else icpx or clangpp or gxx
7071
if cxx is None:
7172
raise RuntimeError("Failed to find C++ compiler. Please specify via CXX environment variable.")
7273
cc = cxx
7374
import numpy as np
7475
numpy_include_dir = np.get_include()
7576
include_dirs = include_dirs + [numpy_include_dir]
76-
if icpx is not None:
77+
if cxx is icpx:
7778
extra_compile_args += ["-fsycl"]
7879
else:
7980
extra_compile_args += ["--std=c++17"]
8081
if os.name == "nt":
81-
library_dirs = library_dirs + [os.path.join(sysconfig.get_paths(scheme=scheme)["stdlib"], "..", "libs")]
82+
library_dirs = library_dirs + [
83+
os.path.abspath(os.path.join(sysconfig.get_paths(scheme=scheme)["stdlib"], "..", "libs"))
84+
]
8285
else:
8386
cc_cmd = [cc]
8487

third_party/intel/backend/driver.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,12 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], str]:
6060
# sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders
6161
# being add: include and include/sycl.
6262
if f.name == "sycl.hpp":
63-
include_dir += [f.locate().parent.parent.resolve().as_posix()]
64-
if f.name == "libsycl.so":
65-
sycl_dir = f.locate().parent.resolve().as_posix()
63+
include_dir += [str(f.locate().parent.parent.resolve())]
64+
if f.name in ["libsycl.so", "sycl8.dll"]:
65+
sycl_dir = str(f.locate().parent.resolve())
66+
# should we handle `_` somehow?
67+
if os.name == "nt":
68+
_ = os.add_dll_directory(sycl_dir)
6669

6770
return include_dir, sycl_dir
6871

@@ -76,9 +79,7 @@ def __init__(self):
7679
self._library_dir = None
7780
self._include_dir = None
7881
self._libsycl_dir = None
79-
self.libraries = ['ze_loader']
80-
if os.name != "nt":
81-
self.libraries += ["sycl"]
82+
self.libraries = ['ze_loader', 'sycl']
8283

8384
@property
8485
def inject_pytorch_dep(self):
@@ -145,7 +146,11 @@ def compile_module_from_src(src, name):
145146
f.write(src)
146147
extra_compiler_args = []
147148
if COMPILATION_HELPER.libsycl_dir:
148-
extra_compiler_args += ['-Wl,-rpath,' + COMPILATION_HELPER.libsycl_dir]
149+
if os.name == "nt":
150+
extra_compiler_args += ["/LIBPATH:" + COMPILATION_HELPER.libsycl_dir]
151+
else:
152+
extra_compiler_args += ["-Wl,-rpath," + COMPILATION_HELPER.libsycl_dir]
153+
149154
so = _build(name, src_path, tmpdir, COMPILATION_HELPER.library_dir, COMPILATION_HELPER.include_dir,
150155
COMPILATION_HELPER.libraries, extra_compile_args=extra_compiler_args)
151156
with open(so, "rb") as f:

0 commit comments

Comments
 (0)