@@ -27,7 +27,7 @@ def quiet():
2727
2828def _cc_cmd (cc , src , out , include_dirs , library_dirs , libraries ):
2929 if cc in ["cl" , "clang-cl" ]:
30- cc_cmd = [cc , src , "/nologo" , "/O2" , "/LD" , "-std:c++20" ]
30+ cc_cmd = [cc , src , "/nologo" , "/O2" , "/LD" ]
3131 cc_cmd += [f"/I{ dir } " for dir in include_dirs ]
3232 cc_cmd += [f"/Fo{ os .path .join (os .path .dirname (out ), 'main.obj' )} " ]
3333 cc_cmd += ["/link" ]
@@ -37,14 +37,14 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
3737 cc_cmd += [f"/LIBPATH:{ dir } " for dir in library_dirs ]
3838 cc_cmd += [f'{ lib } .lib' for lib in libraries ]
3939 else :
40- cc_cmd = [cc , src , "-O3" , "-shared" , "-fPIC" ]
40+ cc_cmd = [cc , src , "-O3" , "-shared" ]
41+ if os .name != "nt" :
42+ cc_cmd += ["fPIC" ]
4143 cc_cmd += [f'-l{ lib } ' for lib in libraries ]
4244 cc_cmd += [f"-L{ dir } " for dir in library_dirs ]
4345 cc_cmd += [f"-I{ dir } " for dir in include_dirs ]
4446 cc_cmd += ["-o" , out ]
4547
46- if os .name == "nt" : cc_cmd .pop (cc_cmd .index ("-fPIC" ))
47-
4848 return cc_cmd
4949
5050
@@ -75,6 +75,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
7575 py_include_dir = sysconfig .get_paths (scheme = scheme )["include" ]
7676 custom_backend_dirs = set (os .getenv (var ) for var in ('TRITON_CUDACRT_PATH' , 'TRITON_CUDART_PATH' ))
7777 include_dirs = include_dirs + [srcdir , py_include_dir , * custom_backend_dirs ]
78+ extra_compiler_flags = []
7879
7980 if is_xpu ():
8081 icpx = None
@@ -83,22 +84,25 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
8384 clangpp = shutil .which ("clang++" )
8485 gxx = shutil .which ("g++" )
8586 icpx = shutil .which ("icpx" )
86- cxx = icpx or clangpp or gxx
87+ cxx = icpx if os . name == "nt" else icpx or clangpp or gxx
8788 if cxx is None :
8889 raise RuntimeError ("Failed to find C++ compiler. Please specify via CXX environment variable." )
90+ cc = cxx
8991 import numpy as np
9092 numpy_include_dir = np .get_include ()
9193 include_dirs = include_dirs + [numpy_include_dir ]
92- cc_cmd = [cxx ]
9394 if icpx is not None :
94- cc_cmd += ["-fsycl" ]
95+ extra_compiler_flags += ["-fsycl" ]
9596 else :
96- cc_cmd += ["--std=c++17" ]
97+ extra_compiler_flags += ["--std=c++17" ]
98+ if os .name == "nt" :
99+ library_dirs += [os .path .join (sysconfig .get_paths (scheme = scheme )["stdlib" ], ".." , "libs" )]
97100 else :
98101 cc_cmd = [cc ]
99102
100103 # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
101104 cc_cmd = _cc_cmd (cc , src , so , include_dirs , library_dirs , libraries )
105+ cc_cmd += extra_compiler_flags
102106
103107 if os .getenv ("VERBOSE" ):
104108 print (" " .join (cc_cmd ))
0 commit comments