33import subprocess
44import sys
55import tempfile
6+ import shutil
7+ import sysconfig
68
79import numpy as np
810
@@ -99,19 +101,34 @@ def kernel(C, A, B, M, N, K,
99101}"""
100102
101103
104+ def select_compiler ():
105+ gxx = shutil .which ("g++" )
106+ icpx = shutil .which ("icpx" )
107+ cl = shutil .which ("cl" )
108+ cxx = (icpx or cl ) if os .name == "nt" else (icpx or gxx )
109+ if cxx is None :
110+ raise RuntimeError ("Failed to find C++ compiler. Please specify via CXX environment variable." )
111+ return cxx
112+
113+
102114def gen_kernel_library_xpu (dir , libname ):
103115 cpp_files = glob .glob (os .path .join (dir , "*.cpp" ))
104- subprocess .run (
105- ["g++" ] + cpp_files + ["-I" + include_dir for include_dir in COMPILATION_HELPER .include_dir ] + ["-c" , "-fPIC" ],
106- check = True ,
107- cwd = dir ,
108- )
116+ cxx = select_compiler ()
117+ command = [cxx ] + cpp_files + ["-I" + include_dir for include_dir in COMPILATION_HELPER .include_dir
118+ ] + ["-c" , "-fPIC" if os .name != "nt" else "-Wno-deprecated-declarations" ]
119+ subprocess .run (command , check = True , cwd = dir )
109120 o_files = glob .glob (os .path .join (dir , "*.o" ))
110121
111- subprocess .run (["g++" ] + [* o_files , "-shared" , "-o" , libname ] +
112- ["-L" + library_dir for library_dir in COMPILATION_HELPER .library_dir ] +
113- ["-L" + dir
114- for dir in COMPILATION_HELPER .libsycl_dir ] + ["-lsycl" , "-lze_loader" ], check = True , cwd = dir )
122+ extra_link_args = []
123+ if "icpx" in cxx and os .name == "nt" :
124+ libname_without_ext = libname .split ("." )[0 ]
125+ extra_link_args = [f"/IMPLIB:{ libname_without_ext } .lib" ]
126+
127+ command = [cxx ] + [* o_files , "-shared" , "-o" , libname ] + [
128+ "-L" + library_dir for library_dir in COMPILATION_HELPER .library_dir
129+ ] + ["-L" + dir for dir in COMPILATION_HELPER .libsycl_dir
130+ ] + ["-lsycl8" if os .name == "nt" else "-lsycl" , "-lze_loader" ] + extra_link_args
131+ subprocess .run (command , check = True , cwd = dir )
115132
116133
117134def gen_kernel_library (dir , libname ):
@@ -133,6 +150,8 @@ def gen_kernel_library(dir, libname):
133150
134151
135152def gen_test_bin (dir , M , N , K , exe = "test" , algo_id = 0 ):
153+ exe_extension = sysconfig .get_config_var ("EXE" )
154+ exe = exe + exe_extension
136155 test_src = f"""
137156int main(int argc, char **argv) {{
138157 int M = { M } , N = { N } , K = { K } ;
@@ -294,15 +313,18 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
294313 command .extend (["-l" , "cuda" , "-L" , dir , "-l" , "kernel" , "-o" , exe ])
295314
296315 if is_xpu ():
297- command = ["g++" , "test.cpp" ]
316+ cxx = select_compiler ()
317+ command = [cxx , "test.cpp" ]
298318 for inc_dir in COMPILATION_HELPER .include_dir :
299319 command .extend (["-I" , inc_dir ])
300320 for lib_dir in COMPILATION_HELPER .library_dir :
301321 command .extend (["-L" , lib_dir ])
302322 if COMPILATION_HELPER .libsycl_dir :
303323 for lib_dir in COMPILATION_HELPER .libsycl_dir :
304324 command .extend (["-L" , lib_dir ])
305- command .extend (["-lsycl" , "-lze_loader" , "-L" , dir , "-l" , "kernel" , "-o" , exe ])
325+ if os .name == "nt" :
326+ command .extend (["-Wno-deprecated-declarations" ])
327+ command .extend (["-lsycl8" if os .name == "nt" else "-lsycl" , "-lze_loader" , "-L" , dir , "-lkernel" , "-o" , exe ])
306328 subprocess .run (command , check = True , cwd = dir )
307329
308330
@@ -415,7 +437,7 @@ def test_compile_link_matmul_no_specialization():
415437
416438 # compile test case
417439 M , N , K = 16 , 16 , 16
418- gen_kernel_library (tmp_dir , "libkernel.so" )
440+ gen_kernel_library (tmp_dir , "libkernel.so" if os . name != "nt" else "kernel.dll" )
419441 gen_test_bin (tmp_dir , M , N , K )
420442
421443 # initialize test data
@@ -424,7 +446,7 @@ def test_compile_link_matmul_no_specialization():
424446 # run test case
425447 env = os .environ .copy ()
426448 env ["LD_LIBRARY_PATH" ] = tmp_dir + ":" + env .get ("LD_LIBRARY_PATH" , "" )
427- subprocess .run (["./ test" , a_path , b_path , c_path ], env = env , check = True , cwd = tmp_dir )
449+ subprocess .run ([os . path . join ( tmp_dir , " test") , a_path , b_path , c_path ], env = env , check = True , cwd = tmp_dir )
428450 # read data and compare against reference
429451 c = np .genfromtxt (c_path , delimiter = "," , dtype = np .int32 )
430452 c_tri = c .reshape ((M , N )).view (np .float32 )
@@ -445,7 +467,7 @@ def test_compile_link_matmul():
445467
446468 # compile test case
447469 M , N , K = 16 , 16 , 16
448- gen_kernel_library (tmp_dir , "libkernel.so" )
470+ gen_kernel_library (tmp_dir , "libkernel.so" if os . name != "nt" else "kernel.dll" )
449471 gen_test_bin (tmp_dir , M , N , K )
450472
451473 # initialize test data
@@ -454,7 +476,7 @@ def test_compile_link_matmul():
454476 # run test case
455477 env = os .environ .copy ()
456478 env ["LD_LIBRARY_PATH" ] = tmp_dir + ":" + env .get ("LD_LIBRARY_PATH" , "" )
457- subprocess .run (["./ test" , a_path , b_path , c_path ], env = env , check = True , cwd = tmp_dir )
479+ subprocess .run ([os . path . join ( tmp_dir , " test") , a_path , b_path , c_path ], env = env , check = True , cwd = tmp_dir )
458480
459481 # read data and compare against reference
460482 c = np .genfromtxt (c_path , delimiter = "," , dtype = np .int32 )
@@ -476,7 +498,7 @@ def test_launcher_has_no_available_kernel():
476498
477499 # compile test case
478500 M , N , K = 16 , 16 , 16
479- gen_kernel_library (tmp_dir , "libkernel.so" )
501+ gen_kernel_library (tmp_dir , "libkernel.so" if os . name != "nt" else "kernel.dll" )
480502 gen_test_bin (tmp_dir , M , N , K )
481503
482504 # initialize test data
@@ -486,15 +508,16 @@ def test_launcher_has_no_available_kernel():
486508 env = os .environ .copy ()
487509 env ["LD_LIBRARY_PATH" ] = tmp_dir + ":" + env .get ("LD_LIBRARY_PATH" , "" )
488510 result = subprocess .run (
489- ["./ test" , a_path , b_path , c_path ],
511+ [os . path . join ( tmp_dir , " test") , a_path , b_path , c_path ],
490512 env = env ,
491513 cwd = tmp_dir ,
492514 capture_output = True ,
493515 text = True ,
494516 )
495517
496518 # It should fail since the launcher requires all the strides be 1 while they are not.
497- assert result .returncode == - 6
519+ # On windows: 3221226505 == 0xc0000409: STATUS_STACK_BUFFER_OVERRUN
520+ assert result .returncode == - 6 if os .name != "nt" else 0xc0000409
498521 assert "kernel launch failed" in result .stderr
499522
500523
@@ -519,7 +542,7 @@ def test_compile_link_autotune_matmul():
519542
520543 link_aot_kernels (tmp_dir )
521544
522- gen_kernel_library (tmp_dir , "libkernel.so" )
545+ gen_kernel_library (tmp_dir , "libkernel.so" if os . name != "nt" else "kernel.dll" )
523546
524547 # compile test case
525548 M , N , K = 64 , 64 , 64
@@ -535,7 +558,7 @@ def test_compile_link_autotune_matmul():
535558 env = os .environ .copy ()
536559 env ["LD_LIBRARY_PATH" ] = tmp_dir + ":" + env .get ("LD_LIBRARY_PATH" , "" )
537560 subprocess .run (
538- [f"./ { test_name } " , a_path , b_path , c_path ],
561+ [os . path . join ( tmp_dir , test_name ) , a_path , b_path , c_path ],
539562 check = True ,
540563 cwd = tmp_dir ,
541564 env = env ,
0 commit comments