Skip to content

Commit 3d2944d

Browse files
authored
Enable tools/test_aot.py for Windows+icpx (#3629)
I plan to add msvc support separately. --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 7375302 commit 3d2944d

File tree

2 files changed

+60
-27
lines changed

2 files changed

+60
-27
lines changed

python/test/unit/tools/test_aot.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import subprocess
44
import sys
55
import tempfile
6+
import shutil
7+
import sysconfig
68

79
import numpy as np
810

@@ -99,19 +101,34 @@ def kernel(C, A, B, M, N, K,
99101
}"""
100102

101103

104+
def select_compiler():
    """Return the C++ compiler used to build the AOT test artifacts.

    A non-empty ``CXX`` environment variable takes precedence, which makes
    the hint in the error message below actionable. Otherwise the compiler
    is looked up on ``PATH``: ``icpx`` is preferred, with ``cl`` as the
    fallback on Windows (``os.name == "nt"``) and ``g++`` elsewhere.

    Returns:
        str: The value of ``CXX`` when it is set, otherwise the full path
        reported by ``shutil.which`` for the selected compiler.

    Raises:
        RuntimeError: If ``CXX`` is not set and no candidate compiler is
            found on ``PATH``.
    """
    cxx_override = os.environ.get("CXX")
    if cxx_override:
        # Honor the explicit user choice that the error message advertises.
        return cxx_override

    icpx = shutil.which("icpx")
    if icpx is not None:
        return icpx

    # icpx is unavailable: fall back to the platform's conventional compiler.
    fallback_name = "cl" if os.name == "nt" else "g++"
    cxx = shutil.which(fallback_name)
    if cxx is None:
        raise RuntimeError("Failed to find C++ compiler. Please specify via CXX environment variable.")
    return cxx
112+
113+
102114
def gen_kernel_library_xpu(dir, libname):
    """Compile every ``*.cpp`` file in ``dir`` and link the objects into ``libname``.

    Both compiler invocations run with ``dir`` as the working directory and
    raise ``subprocess.CalledProcessError`` on a non-zero exit status
    (``check=True``).

    Args:
        dir: Directory that contains the generated ``*.cpp`` sources; object
            files and the resulting library are produced there as well.
        libname: File name of the shared library to create
            (e.g. ``libkernel.so`` or ``kernel.dll``).
    """
    on_windows = os.name == "nt"
    cxx = select_compiler()

    # Step 1: compile all sources to object files.
    cpp_files = glob.glob(os.path.join(dir, "*.cpp"))
    include_flags = ["-I" + include_dir for include_dir in COMPILATION_HELPER.include_dir]
    # "-fPIC" is only passed off Windows; on Windows the deprecation warnings
    # are silenced instead.
    compile_flags = ["-c", "-Wno-deprecated-declarations" if on_windows else "-fPIC"]
    subprocess.run([cxx] + cpp_files + include_flags + compile_flags, check=True, cwd=dir)
    o_files = glob.glob(os.path.join(dir, "*.o"))

    # Step 2: link the objects into a shared library.
    extra_link_args = []
    # Inspect only the executable name: a substring test on the full path
    # would also match an unrelated compiler installed under an "icpx" folder.
    compiler_name = os.path.basename(cxx).lower()
    if on_windows and compiler_name.startswith("icpx"):
        # splitext strips just the final extension, so dotted names such as
        # "kernel.v2.dll" keep their full stem.
        libname_without_ext = os.path.splitext(libname)[0]
        # NOTE(review): /IMPLIB presumably names the import library that the
        # test binary later links against via "-lkernel" -- confirm.
        extra_link_args = [f"/IMPLIB:{libname_without_ext}.lib"]

    library_flags = ["-L" + library_dir for library_dir in COMPILATION_HELPER.library_dir]
    sycl_flags = ["-L" + sycl_dir for sycl_dir in COMPILATION_HELPER.libsycl_dir]
    runtime_libs = ["-lsycl8" if on_windows else "-lsycl", "-lze_loader"]
    command = [cxx] + [*o_files, "-shared", "-o", libname] + library_flags + sycl_flags + runtime_libs + extra_link_args
    subprocess.run(command, check=True, cwd=dir)
115132

116133

117134
def gen_kernel_library(dir, libname):
@@ -133,6 +150,8 @@ def gen_kernel_library(dir, libname):
133150

134151

135152
def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
153+
exe_extension = sysconfig.get_config_var("EXE")
154+
exe = exe + exe_extension
136155
test_src = f"""
137156
int main(int argc, char **argv) {{
138157
int M = {M}, N = {N}, K = {K};
@@ -294,15 +313,18 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
294313
command.extend(["-l", "cuda", "-L", dir, "-l", "kernel", "-o", exe])
295314

296315
if is_xpu():
297-
command = ["g++", "test.cpp"]
316+
cxx = select_compiler()
317+
command = [cxx, "test.cpp"]
298318
for inc_dir in COMPILATION_HELPER.include_dir:
299319
command.extend(["-I", inc_dir])
300320
for lib_dir in COMPILATION_HELPER.library_dir:
301321
command.extend(["-L", lib_dir])
302322
if COMPILATION_HELPER.libsycl_dir:
303323
for lib_dir in COMPILATION_HELPER.libsycl_dir:
304324
command.extend(["-L", lib_dir])
305-
command.extend(["-lsycl", "-lze_loader", "-L", dir, "-l", "kernel", "-o", exe])
325+
if os.name == "nt":
326+
command.extend(["-Wno-deprecated-declarations"])
327+
command.extend(["-lsycl8" if os.name == "nt" else "-lsycl", "-lze_loader", "-L", dir, "-lkernel", "-o", exe])
306328
subprocess.run(command, check=True, cwd=dir)
307329

308330

@@ -415,7 +437,7 @@ def test_compile_link_matmul_no_specialization():
415437

416438
# compile test case
417439
M, N, K = 16, 16, 16
418-
gen_kernel_library(tmp_dir, "libkernel.so")
440+
gen_kernel_library(tmp_dir, "libkernel.so" if os.name != "nt" else "kernel.dll")
419441
gen_test_bin(tmp_dir, M, N, K)
420442

421443
# initialize test data
@@ -424,7 +446,7 @@ def test_compile_link_matmul_no_specialization():
424446
# run test case
425447
env = os.environ.copy()
426448
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
427-
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
449+
subprocess.run([os.path.join(tmp_dir, "test"), a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
428450
# read data and compare against reference
429451
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
430452
c_tri = c.reshape((M, N)).view(np.float32)
@@ -445,7 +467,7 @@ def test_compile_link_matmul():
445467

446468
# compile test case
447469
M, N, K = 16, 16, 16
448-
gen_kernel_library(tmp_dir, "libkernel.so")
470+
gen_kernel_library(tmp_dir, "libkernel.so" if os.name != "nt" else "kernel.dll")
449471
gen_test_bin(tmp_dir, M, N, K)
450472

451473
# initialize test data
@@ -454,7 +476,7 @@ def test_compile_link_matmul():
454476
# run test case
455477
env = os.environ.copy()
456478
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
457-
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
479+
subprocess.run([os.path.join(tmp_dir, "test"), a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
458480

459481
# read data and compare against reference
460482
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
@@ -476,7 +498,7 @@ def test_launcher_has_no_available_kernel():
476498

477499
# compile test case
478500
M, N, K = 16, 16, 16
479-
gen_kernel_library(tmp_dir, "libkernel.so")
501+
gen_kernel_library(tmp_dir, "libkernel.so" if os.name != "nt" else "kernel.dll")
480502
gen_test_bin(tmp_dir, M, N, K)
481503

482504
# initialize test data
@@ -486,15 +508,16 @@ def test_launcher_has_no_available_kernel():
486508
env = os.environ.copy()
487509
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
488510
result = subprocess.run(
489-
["./test", a_path, b_path, c_path],
511+
[os.path.join(tmp_dir, "test"), a_path, b_path, c_path],
490512
env=env,
491513
cwd=tmp_dir,
492514
capture_output=True,
493515
text=True,
494516
)
495517

496518
# It should fail since the launcher requires all the strides be 1 while they are not.
497-
assert result.returncode == -6
519+
# On windows: 3221226505 == 0xc0000409: STATUS_STACK_BUFFER_OVERRUN
520+
assert result.returncode == (-6 if os.name != "nt" else 0xc0000409)
498521
assert "kernel launch failed" in result.stderr
499522

500523

@@ -519,7 +542,7 @@ def test_compile_link_autotune_matmul():
519542

520543
link_aot_kernels(tmp_dir)
521544

522-
gen_kernel_library(tmp_dir, "libkernel.so")
545+
gen_kernel_library(tmp_dir, "libkernel.so" if os.name != "nt" else "kernel.dll")
523546

524547
# compile test case
525548
M, N, K = 64, 64, 64
@@ -535,7 +558,7 @@ def test_compile_link_autotune_matmul():
535558
env = os.environ.copy()
536559
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
537560
subprocess.run(
538-
[f"./{test_name}", a_path, b_path, c_path],
561+
[os.path.join(tmp_dir, test_name), a_path, b_path, c_path],
539562
check=True,
540563
cwd=tmp_dir,
541564
env=env,

python/triton/tools/link.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,10 @@ def make_global_decl(meta: KernelLinkerMeta) -> str:
159159
"""
160160
if is_xpu():
161161
return f"""
162-
int32_t {meta.orig_kernel_name}_default(sycl::queue &stream, {gen_signature_with_full_args(meta)});
163-
int32_t {meta.orig_kernel_name}(sycl::queue &stream, {gen_signature_with_full_args(meta)}, int algo_id);
164-
void load_{meta.orig_kernel_name}();
165-
void unload_{meta.orig_kernel_name}();
162+
EXPORT_FUNC int32_t {meta.orig_kernel_name}_default(sycl::queue &stream, {gen_signature_with_full_args(meta)});
163+
EXPORT_FUNC int32_t {meta.orig_kernel_name}(sycl::queue &stream, {gen_signature_with_full_args(meta)}, int algo_id);
164+
EXPORT_FUNC void load_{meta.orig_kernel_name}();
165+
EXPORT_FUNC void unload_{meta.orig_kernel_name}();
166166
"""
167167

168168

@@ -172,7 +172,7 @@ def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
172172
src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
173173
src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
174174
if is_xpu():
175-
src = f"int32_t {meta.orig_kernel_name}_default(sycl::queue &stream, {gen_signature_with_full_args(meta)}){{\n"
175+
src = f"EXPORT_FUNC int32_t {meta.orig_kernel_name}_default(sycl::queue &stream, {gen_signature_with_full_args(meta)}){{\n"
176176
src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n"
177177
src += "}\n"
178178
return src
@@ -245,7 +245,7 @@ def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str:
245245
if is_cuda():
246246
src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n"
247247
if is_xpu():
248-
src = f"int32_t {meta.orig_kernel_name}(sycl::queue &stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n"
248+
src = f"EXPORT_FUNC int32_t {meta.orig_kernel_name}(sycl::queue &stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n"
249249
src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n"
250250
if is_cuda():
251251
src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n"
@@ -273,7 +273,7 @@ def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str:
273273
def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str:
274274
src = ""
275275
for mode in ["load", "unload"]:
276-
src += f"void {mode}_{meta.orig_kernel_name}(void){{\n"
276+
src += f"EXPORT_FUNC void {mode}_{meta.orig_kernel_name}(void){{\n"
277277
for name in names:
278278
src += f" {mode}_{name}();\n"
279279
src += "}\n\n"
@@ -344,6 +344,11 @@ def make_get_num_algos_def(meta: KernelLinkerMeta) -> str:
344344
out += "#include <sycl/sycl.hpp>\n"
345345
out += "#include <stdint.h>\n"
346346
out += "#include <stdio.h>\n"
347+
out += "#if defined(_WIN32)\n"
348+
out += "#define EXPORT_FUNC __declspec(dllexport)\n"
349+
out += "#else\n"
350+
out += "#define EXPORT_FUNC\n"
351+
out += "#endif\n"
347352
out += "\n".join(algo_decls)
348353
out += "\n"
349354
out += get_num_algos_decl
@@ -386,6 +391,11 @@ def make_get_num_algos_def(meta: KernelLinkerMeta) -> str:
386391
out += "#include <stdint.h>\n"
387392
out += "#include <assert.h>\n"
388393
out += "#include <cstdint>\n"
394+
out += "#if defined(_WIN32)\n"
395+
out += "#define EXPORT_FUNC __declspec(dllexport)\n"
396+
out += "#else\n"
397+
out += "#define EXPORT_FUNC\n"
398+
out += "#endif\n"
389399
out += "\n"
390400
out += "\n".join(defs)
391401
out += "\n"

0 commit comments

Comments
 (0)