Skip to content

Commit 513aede

Browse files
authored
Windows native port (#2478)
Fixes #2407. In its current state the code builds on Windows and passes many unit tests. Further fixes will follow once this patch is integrated into the main branch. --------- Signed-off-by: Gregory Shimansky <[email protected]>
1 parent d254e2b commit 513aede

File tree

7 files changed

+229
-33
lines changed

7 files changed

+229
-33
lines changed

CMakeLists.txt

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,24 @@ endif()
88

99
include(ExternalProject)
1010

11-
set(CMAKE_CXX_STANDARD 17)
12-
1311
set(CMAKE_INCLUDE_CURRENT_DIR ON)
1412

1513
project(triton CXX)
1614
include(CTest)
1715

18-
if(NOT WIN32)
19-
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
20-
endif()
21-
22-
16+
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
2317

2418
# Options
19+
# The Proton profiler defaults to ON everywhere except on Windows, where it
# defaults to OFF.
set(DEFAULT_BUILD_PROTON ON)
if(WIN32)
  set(DEFAULT_BUILD_PROTON OFF)
endif()

# Define the option with the determined default value
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ${DEFAULT_BUILD_PROTON})
2527
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
2628
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
27-
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
2829
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
2930
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
3031
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
@@ -49,10 +50,21 @@ endif()
4950
# used conditionally in this file and by lit tests
5051

5152
# Customized release build type with assertions: TritonRelBuildWithAsserts
52-
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
53-
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
54-
set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1")
55-
set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1")
53+
# Flags for the custom build types TritonRelBuildWithAsserts and
# TritonBuildWithO1. The C++ standard is selected here too: 20 for MSVC,
# 17 for every other compiler.
if(MSVC)
  set(CMAKE_CXX_STANDARD 20)
  foreach(triton_lang IN ITEMS C CXX)
    set(CMAKE_${triton_lang}_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor")
  endforeach()
  # NOTE(review): the same options are applied to the static librarian
  # (STATIC) as to the linker; confirm lib.exe accepts them.
  foreach(triton_link_kind IN ITEMS EXE MODULE SHARED STATIC)
    set(CMAKE_${triton_link_kind}_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
  endforeach()
  # No TritonBuildWithO1 flags are defined in the MSVC branch.
else()
  set(CMAKE_CXX_STANDARD 17)
  foreach(triton_lang IN ITEMS C CXX)
    set(CMAKE_${triton_lang}_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
    set(CMAKE_${triton_lang}_FLAGS_TRITONBUILDWITHO1 "-O1")
  endforeach()
endif()
5668

5769
# Default build type
5870
if(NOT CMAKE_BUILD_TYPE)
@@ -70,7 +82,15 @@ endif()
7082

7183
# Compiler flags
7284
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
73-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
85+
# __STDC_FORMAT_MACROS is defined for every toolchain; the remaining flags
# depend on the compiler and the target platform.
if(MSVC)
  # MSVC: silence warnings C4244, C4624, C4715 and C4530.
  string(APPEND CMAKE_CXX_FLAGS " -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530")
elseif(WIN32)
  # Non-MSVC compiler targeting Windows: no -fPIC.
  string(APPEND CMAKE_CXX_FLAGS " -D__STDC_FORMAT_MACROS -Wno-deprecated")
else()
  string(APPEND CMAKE_CXX_FLAGS " -D__STDC_FORMAT_MACROS -fPIC")
endif()
7494

7595

7696
# #########
@@ -124,7 +144,11 @@ endfunction()
124144

125145

126146
# Disable warnings that show up in external code (gtest;pybind11)
127-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
147+
# Warning policy: warnings are errors for GCC/Clang-style compilers, while
# MSVC gets /WX- so that warnings are not treated as errors.
if(MSVC)
  string(APPEND CMAKE_CXX_FLAGS " /WX-")
else()
  string(APPEND CMAKE_CXX_FLAGS " -Werror -Wno-covered-switch-default -fvisibility=hidden")
endif()
128152

129153
include_directories(".")
130154
include_directories(${MLIR_INCLUDE_DIRS})
@@ -134,7 +158,8 @@ include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
134158
include_directories(${PROJECT_SOURCE_DIR}/third_party)
135159
include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files
136160

137-
# link_directories(${LLVM_LIBRARY_DIR})
161+
link_directories(${LLVM_LIBRARY_DIR})
162+
138163
add_subdirectory(include)
139164
add_subdirectory(lib)
140165

@@ -163,6 +188,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
163188
# using pip install.
164189
include_directories(${PYTHON_INCLUDE_DIRS})
165190
include_directories(${PYBIND11_INCLUDE_DIR})
191+
message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}")
192+
link_directories(${PYTHON_LIB_DIRS})
166193
else()
167194
# Otherwise, we might be building from top CMakeLists.txt directly.
168195
# Try to find Python and pybind11 packages.
@@ -245,7 +272,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
245272
LLVMAArch64CodeGen
246273
LLVMAArch64AsmParser
247274
)
248-
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
275+
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64")
249276
list(APPEND TRITON_LIBRARIES
250277
LLVMX86CodeGen
251278
LLVMX86AsmParser
@@ -280,6 +307,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
280307
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
281308
if(WIN32)
282309
target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
310+
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
311+
set_target_properties(triton PROPERTIES PREFIX "lib")
283312
else()
284313
target_link_libraries(triton PRIVATE z)
285314
endif()
@@ -306,6 +335,10 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
306335
add_subdirectory(third_party/${CODEGEN_BACKEND})
307336
endforeach()
308337
endif()
338+
# On Windows, declare two options that default to ON: one disabling pthreads
# in gtest and one selecting WIN32 threads.
if(WIN32)
  option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
  option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
endif()
309342

310343
add_subdirectory(third_party/f2reduce)
311344
add_subdirectory(bin)

python/setup.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,49 @@ def copy_externals():
103103
]
104104

105105

106+
def find_vswhere():
    """Locate ``vswhere.exe`` inside the Visual Studio Installer directory.

    The base directory comes from the ``ProgramFiles(x86)`` environment
    variable, falling back to ``C:\\Program Files (x86)`` when it is unset.

    Returns:
        A ``pathlib.Path`` pointing at ``vswhere.exe`` if it exists,
        otherwise ``None``.
    """
    base_dir = Path(os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)"))
    candidate = base_dir.joinpath("Microsoft Visual Studio", "Installer", "vswhere.exe")
    return candidate if candidate.exists() else None
112+
113+
114+
def find_visual_studio(version_ranges):
    """Return the installation path of a Visual Studio with the C++ x86/x64 tools.

    ``vswhere.exe`` is queried once per entry of ``version_ranges``, in the
    given order; the first range that produces a result wins. Prerelease
    installations are included in the search.

    Args:
        version_ranges: Iterable of vswhere version-range strings such as
            ``"[17.0,18.0)"``, ordered from most to least preferred.

    Returns:
        The installation path as a string, or ``None`` when no range matched.

    Raises:
        FileNotFoundError: If ``vswhere.exe`` cannot be located.
    """
    vswhere = find_vswhere()
    if not vswhere:
        raise FileNotFoundError("vswhere.exe not found.")

    for version_range in version_ranges:
        command = [
            str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
            "-property", "installationPath", "-prerelease"
        ]

        try:
            output = subprocess.check_output(command, text=True)
        except subprocess.CalledProcessError:
            # A failing query for this range is not fatal; try the next one.
            continue

        # vswhere prints one line per matching instance. Return a single
        # path (the first one listed) instead of the whole multi-line blob,
        # which callers would otherwise treat as one bogus path.
        for line in output.splitlines():
            install_path = line.strip()
            if install_path:
                return install_path

    return None
133+
134+
135+
def set_env_vars(vs_path, arch="x64"):
    """Import the environment set up by ``vcvarsall.bat`` into ``os.environ``.

    Runs ``vcvarsall.bat`` for ``arch`` in a shell, dumps the resulting
    environment with ``set`` and copies every ``NAME=value`` line into the
    current process environment.

    Args:
        vs_path: Visual Studio installation directory.
        arch: Architecture argument passed to ``vcvarsall.bat``.

    Returns:
        A dict with the variables that were written to ``os.environ``.

    Raises:
        FileNotFoundError: If ``vcvarsall.bat`` is not found under ``vs_path``.
        subprocess.CalledProcessError: If the shell command fails.
    """
    vcvarsall_path = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat"
    if not vcvarsall_path.exists():
        raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}")

    # With shell=True the command is handed to the shell as text, so build it
    # as one string and quote the script path explicitly (it normally
    # contains spaces) instead of relying on list-to-string conversion.
    command = f'call "{vcvarsall_path}" {arch} && set'
    output = subprocess.check_output(command, shell=True, text=True)

    applied = {}
    for line in output.splitlines():
        var, sep, value = line.partition('=')
        # Skip banner lines without '=' and entries with an empty name,
        # which cannot be stored in os.environ.
        if not sep or not var:
            continue
        os.environ[var] = value
        applied[var] = value
    return applied
147+
148+
106149
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
107150
def check_env_flag(name: str, default: str = "") -> bool:
    """Report whether environment variable ``name`` holds a truthy flag value.

    The value (or ``default`` when the variable is unset) is compared
    case-insensitively against ON, 1, YES, TRUE and Y.
    """
    truthy_values = ["ON", "1", "YES", "TRUE", "Y"]
    raw_value = os.getenv(name, default)
    return raw_value.upper() in truthy_values
@@ -196,6 +239,8 @@ def get_llvm_package_info():
196239
f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build."
197240
)
198241
return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
242+
elif system == 'Windows':
243+
system_suffix = "windows-x64"
199244
else:
200245
print(
201246
f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build."
@@ -281,10 +326,10 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
281326
base_dir = os.path.dirname(__file__)
282327
system = platform.system()
283328
try:
284-
arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
329+
arch = {"x86_64": "64", "AMD64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
285330
except KeyError:
286331
arch = platform.machine()
287-
supported = {"Linux": "linux", "Darwin": "linux"}
332+
supported = {"Linux": "linux", "Darwin": "linux", "Windows": "win"}
288333
url = url_func(supported[system], arch, version)
289334
tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download
290335
dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path
@@ -401,6 +446,11 @@ def get_proton_cmake_args(self):
401446
def build_extension(self, ext):
402447
lit_dir = shutil.which('lit')
403448
ninja_dir = shutil.which('ninja')
449+
if platform.system() == "Windows":
450+
vs_path = find_visual_studio(["[17.0,18.0)", "[16.0,17.0)"])
451+
env = set_env_vars(vs_path)
452+
if not vs_path:
453+
raise EnvironmentError("Visual Studio 2019 or 2022 not found.")
404454
# lit is used by the test suite
405455
thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
406456
thirdparty_cmake_args += self.get_pybind11_cmake_args()
@@ -421,6 +471,10 @@ def build_extension(self, ext):
421471
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
422472
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
423473
]
474+
if platform.system() == "Windows":
475+
installed_base = sysconfig.get_config_var('installed_base')
476+
py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
477+
cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs)
424478
if lit_dir is not None:
425479
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
426480
cmake_args.extend(thirdparty_cmake_args)

python/triton/runtime/CLFinder.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import os
2+
import subprocess
3+
from pathlib import Path
4+
5+
6+
def find_vswhere():
    """Return the path to ``vswhere.exe`` if the Visual Studio Installer has it.

    Looks under the directory named by the ``ProgramFiles(x86)`` environment
    variable (default ``C:\\Program Files (x86)``).

    Returns:
        ``pathlib.Path`` of ``vswhere.exe`` when present, else ``None``.
    """
    program_files_dir = Path(os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)"))
    vswhere_exe = program_files_dir.joinpath("Microsoft Visual Studio", "Installer", "vswhere.exe")
    if not vswhere_exe.exists():
        return None
    return vswhere_exe
12+
13+
14+
def find_visual_studio(version_ranges):
    """Find a Visual Studio installation that provides the C++ x86/x64 tools.

    Each entry of ``version_ranges`` is passed to ``vswhere.exe`` in turn
    (prerelease installations included); the first range with a match
    determines the result.

    Args:
        version_ranges: Iterable of vswhere version-range strings such as
            ``"[17.0,18.0)"``, in order of preference.

    Returns:
        The installation path as a string, or ``None`` if nothing matched.

    Raises:
        FileNotFoundError: If ``vswhere.exe`` cannot be located.
    """
    vswhere = find_vswhere()
    if not vswhere:
        raise FileNotFoundError("vswhere.exe not found.")

    for version_range in version_ranges:
        command = [
            str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
            "-property", "installationPath", "-prerelease"
        ]

        try:
            output = subprocess.check_output(command, text=True)
        except subprocess.CalledProcessError:
            # This range could not be queried; fall through to the next one.
            continue

        # Several installations may match one range, and vswhere then prints
        # one path per line. Return only the first path so the result is
        # always usable as a single directory.
        for line in output.splitlines():
            install_path = line.strip()
            if install_path:
                return install_path

    return None
33+
34+
35+
def set_env_vars(vs_path, arch="x64"):
    """Apply the environment produced by ``vcvarsall.bat`` to this process.

    ``vcvarsall.bat`` is run for ``arch`` in a shell, followed by ``set`` to
    print the resulting environment; every ``NAME=value`` line of that output
    is stored in ``os.environ``.

    Args:
        vs_path: Visual Studio installation directory.
        arch: Architecture argument passed to ``vcvarsall.bat``.

    Returns:
        A dict with the variables that were written to ``os.environ``.

    Raises:
        FileNotFoundError: If ``vcvarsall.bat`` is not found under ``vs_path``.
        subprocess.CalledProcessError: If the shell command fails.
    """
    vcvarsall_path = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat"
    if not vcvarsall_path.exists():
        raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}")

    command = f'call "{vcvarsall_path}" {arch} && set'
    output = subprocess.check_output(command, shell=True, text=True)

    applied = {}
    for line in output.splitlines():
        var, sep, value = line.partition('=')
        # Lines without '=' are banner text; an empty name cannot be stored
        # in os.environ, so such lines are skipped rather than raising.
        if not sep or not var:
            continue
        os.environ[var] = value
        applied[var] = value
    return applied
47+
48+
49+
def initialize_visual_studio_env(version_ranges, arch="x64"):
    """Make sure the Visual Studio build environment for ``arch`` is loaded.

    Nothing is done when ``VSCMD_ARG_TGT_ARCH`` (a variable that
    ``vcvarsall.bat`` sets) already equals ``arch``. Otherwise a Visual
    Studio installation is looked up and its environment is imported.

    Args:
        version_ranges: vswhere version ranges to search, in order.
        arch: Target architecture expected in the environment.

    Raises:
        EnvironmentError: If no installation matches ``version_ranges``.
    """
    already_configured = os.environ.get('VSCMD_ARG_TGT_ARCH') == arch
    if already_configured:
        return

    install_dir = find_visual_studio(version_ranges)
    if not install_dir:
        raise EnvironmentError("Visual Studio not found in specified version ranges.")
    set_env_vars(install_dir, arch)

python/triton/runtime/build.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import shutil
77
import subprocess
88
import setuptools
9+
import platform
10+
from .CLFinder import initialize_visual_studio_env
911

1012

1113
def is_xpu():
@@ -23,6 +25,29 @@ def quiet():
2325
sys.stdout, sys.stderr = old_stdout, old_stderr
2426

2527

28+
def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
29+
if cc in ["cl", "clang-cl"]:
30+
cc_cmd = [cc, src, "/nologo", "/O2", "/LD"]
31+
cc_cmd += [f"/I{dir}" for dir in include_dirs]
32+
cc_cmd += [f"/Fo{os.path.join(os.path.dirname(out), 'main.obj')}"]
33+
cc_cmd += ["/link"]
34+
cc_cmd += [f"/OUT:{out}"]
35+
cc_cmd += [f"/IMPLIB:{os.path.join(os.path.dirname(out), 'main.lib')}"]
36+
cc_cmd += [f"/PDB:{os.path.join(os.path.dirname(out), 'main.pdb')}"]
37+
cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs]
38+
cc_cmd += [f'{lib}.lib' for lib in libraries]
39+
else:
40+
cc_cmd = [cc, src, "-O3", "-shared", "-Wno-psabi"]
41+
if os.name != "nt":
42+
cc_cmd += ["-fPIC"]
43+
cc_cmd += [f'-l{lib}' for lib in libraries]
44+
cc_cmd += [f"-L{dir}" for dir in library_dirs]
45+
cc_cmd += [f"-I{dir}" for dir in include_dirs]
46+
cc_cmd += ["-o", out]
47+
48+
return cc_cmd
49+
50+
2651
def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compile_args=[]):
2752
suffix = sysconfig.get_config_var('EXT_SUFFIX')
2853
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
@@ -33,6 +58,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
3358
clang = shutil.which("clang")
3459
gcc = shutil.which("gcc")
3560
cc = gcc if gcc is not None else clang
61+
if platform.system() == "Windows":
62+
cc = "cl"
63+
initialize_visual_studio_env(["[17.0,18.0)", "[16.0,17.0)"])
3664
if cc is None:
3765
raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
3866
# This function was renamed and made public in Python 3.10
@@ -55,25 +83,24 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
5583
clangpp = shutil.which("clang++")
5684
gxx = shutil.which("g++")
5785
icpx = shutil.which("icpx")
58-
cxx = icpx or clangpp or gxx
86+
cxx = icpx if os.name == "nt" else icpx or clangpp or gxx
5987
if cxx is None:
6088
raise RuntimeError("Failed to find C++ compiler. Please specify via CXX environment variable.")
89+
cc = cxx
6190
import numpy as np
6291
numpy_include_dir = np.get_include()
6392
include_dirs = include_dirs + [numpy_include_dir]
64-
cc_cmd = [cxx]
6593
if icpx is not None:
66-
cc_cmd += ["-fsycl"]
94+
extra_compile_args += ["-fsycl"]
6795
else:
68-
cc_cmd += ["--std=c++17"]
96+
extra_compile_args += ["--std=c++17"]
97+
if os.name == "nt":
98+
library_dirs += [os.path.join(sysconfig.get_paths(scheme=scheme)["stdlib"], "..", "libs")]
6999
else:
70100
cc_cmd = [cc]
71101

72102
# for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
73-
cc_cmd += [src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so]
74-
cc_cmd += [f'-l{lib}' for lib in libraries]
75-
cc_cmd += [f"-L{dir}" for dir in library_dirs]
76-
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
103+
cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
77104
cc_cmd += extra_compile_args
78105

79106
if os.getenv("VERBOSE"):
@@ -90,7 +117,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
90117
language='c',
91118
sources=[src],
92119
include_dirs=include_dirs,
93-
extra_compile_args=extra_compile_args + ['-O3'],
120+
extra_compile_args=extra_compile_args + ['-O3' if "-O3" in cc_cmd else "/O2"],
94121
extra_link_args=extra_link_args,
95122
library_dirs=library_dirs,
96123
libraries=libraries,

0 commit comments

Comments
 (0)