Skip to content

Commit dcb8ff0

Browse files
authored
[Proton] Enable build on windows (#5533)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 696af36 commit dcb8ff0

File tree

20 files changed

+275
-61
lines changed

20 files changed

+275
-61
lines changed

CMakeLists.txt

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,8 @@ include(CTest)
1616
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
1717

1818
# Options
19-
if(WIN32)
20-
set(DEFAULT_BUILD_PROTON OFF)
21-
else()
22-
set(DEFAULT_BUILD_PROTON ON)
23-
endif()
24-
25-
# Define the option with the determined default value
26-
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ${DEFAULT_BUILD_PROTON})
2719
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
20+
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
2821
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
2922
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
3023
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")

lib/Instrumentation/PrintLoadStoreMemSpaces.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,6 @@ extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo llvmGetPassPluginInfo() {
101101
return getPassPluginInfo();
102102
}
103103

104-
#ifdef WIN32
104+
#if defined(_WIN32)
105105
#pragma comment(linker, "/export:llvmGetPassPluginInfo")
106106
#endif

test/lib/Instrumentation/GPUHello.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ llvmGetPassPluginInfo() {
7575
return getPassPluginInfo();
7676
}
7777

78-
#ifdef WIN32
78+
#if defined(_WIN32)
7979
#pragma comment(linker, "/export:llvmGetPassPluginInfo")
8080
#endif

third_party/amd/backend/include/hsa/hsa_ext_amd.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@
6161
#define HSA_AMD_INTERFACE_VERSION_MAJOR 1
6262
#define HSA_AMD_INTERFACE_VERSION_MINOR 6
6363

64+
#if defined(_WIN32)
65+
#define ALWAYS_INLINE __forceinline
66+
#else
67+
#define ALWAYS_INLINE __attribute__((always_inline)) inline
68+
#endif
69+
6470
#ifdef __cplusplus
6571
extern "C" {
6672
#endif
@@ -73,8 +79,7 @@ extern "C" {
7379
* @brief Macro to use to determine that a flag is set when querying flags within uint8_t[8]
7480
* types
7581
*/
76-
static __inline__ __attribute__((always_inline)) bool hsa_flag_isset64(uint8_t* value,
77-
uint32_t bit) {
82+
static ALWAYS_INLINE bool hsa_flag_isset64(uint8_t* value, uint32_t bit) {
7883
unsigned int index = bit / 8;
7984
unsigned int subBit = bit % 8;
8085
return ((uint8_t*)value)[index] & (1 << subBit);

third_party/amd/include/hipblas_instance.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define TRITON_HIPBLAS_INSTANCE_H
33

44
#include "hipblas_types.h"
5-
#ifdef WIN32
5+
#if defined(_WIN32)
66
#define WIN32_LEAN_AND_MEAN
77
#define NOMINMAX
88
#include <windows.h>
@@ -84,7 +84,7 @@ class HipblasLtInstance {
8484
hipblasLtMatmulPreference_t preference = NULL;
8585

8686
void loadHipBlasDylib() {
87-
#ifdef WIN32
87+
#if defined(_WIN32)
8888
assert(0 && "Not implemented on Windows");
8989
#else
9090
if (dylibHandle == nullptr) {
@@ -137,7 +137,7 @@ class HipblasLtInstance {
137137
}
138138

139139
void unloadHipBlasDylib() {
140-
#ifdef WIN32
140+
#if defined(_WIN32)
141141
assert(0 && "Not implemented on Windows");
142142
#else
143143
dlclose(dylibHandle);

third_party/intel/backend/driver.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ struct BuildFlags {
169169

170170
sycl::context get_default_context(const sycl::device &sycl_device) {
171171
const auto &platform = sycl_device.get_platform();
172-
#ifdef WIN32
172+
#if defined(_WIN32)
173173
sycl::context ctx;
174174
try {
175175
#if __SYCL_COMPILER_VERSION >= 20250604

third_party/intel/backend/proton/include/pti/pti_export.h

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,46 @@
66
# define PTI_EXPORT
77
# define PTI_NO_EXPORT
88
#else
9-
# ifndef PTI_EXPORT
10-
# ifdef pti_EXPORTS
11-
/* We are building this library */
12-
# define PTI_EXPORT __attribute__((visibility("default")))
13-
# else
14-
/* We are using this library */
15-
# define PTI_EXPORT __attribute__((visibility("default")))
9+
10+
/* Test _WIN32: it is predefined by every Windows toolchain, whereas WIN32 is
   only defined by build flags or SDK headers and may be absent. */
# ifdef _WIN32
11+
/* Windows (MSVC/MinGW) */
12+
# ifndef PTI_EXPORT
13+
# ifdef pti_EXPORTS
14+
/* We are building this library */
15+
# define PTI_EXPORT __declspec(dllexport)
16+
# else
17+
/* We are using this library */
18+
# define PTI_EXPORT __declspec(dllimport)
19+
# endif
1620
# endif
17-
# endif
1821

19-
# ifndef PTI_NO_EXPORT
20-
# define PTI_NO_EXPORT __attribute__((visibility("hidden")))
21-
# endif
22-
#endif
22+
# ifndef PTI_NO_EXPORT
23+
# define PTI_NO_EXPORT
24+
# endif
25+
26+
# else
27+
28+
/* Linux / Unix — GCC/Clang visibility */
29+
# ifndef PTI_EXPORT
30+
# ifdef pti_EXPORTS
31+
/* We are building this library */
32+
# define PTI_EXPORT __attribute__((visibility("default")))
33+
# else
34+
/* We are using this library */
35+
# define PTI_EXPORT __attribute__((visibility("default")))
36+
# endif
37+
# endif
38+
39+
# ifndef PTI_NO_EXPORT
40+
# define PTI_NO_EXPORT __attribute__((visibility("hidden")))
41+
# endif
42+
43+
# endif /* WIN32 */
44+
45+
#endif /* PTI_STATIC_DEFINE */
2346

2447
#ifndef PTI_DEPRECATED
25-
# define PTI_DEPRECATED
48+
# define PTI_DEPRECATED
2649
#endif
2750

2851
#ifndef PTI_DEPRECATED_EXPORT

third_party/nvidia/backend/driver.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "cuda.h"
2-
#ifdef WIN32
2+
#if defined(_WIN32)
33
#define WIN32_LEAN_AND_MEAN
44
#define NOMINMAX
55
#include <windows.h>
@@ -180,7 +180,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
180180
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
181181
CUtensorMapFloatOOBfill oobFill);
182182

183-
#ifdef WIN32
183+
#if defined(_WIN32)
184184
#define defineGetFunctionHandle(name, symbolName) \
185185
static symbolName##_t name() { \
186186
/* Open the shared library */ \
@@ -196,6 +196,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
196196
if (err) { \
197197
PyErr_SetString(PyExc_RuntimeError, \
198198
"Failed to retrieve " #symbolName " from nvcuda.dll"); \
199+
FreeLibrary(handle); \
199200
return NULL; \
200201
} \
201202
return funcHandle; \

third_party/nvidia/include/cublas_instance.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define TRITON_CUBLAS_INSTANCE_H
33

44
#include "cublas_types.h"
5-
#ifdef WIN32
5+
#if defined(_WIN32)
66
#define WIN32_LEAN_AND_MEAN
77
#define NOMINMAX
88
#include <windows.h>
@@ -70,7 +70,7 @@ class CublasLtInstance {
7070
cublasLtMatmulPreference_t preference = NULL;
7171

7272
void loadCublasDylib() {
73-
#ifdef WIN32
73+
#if defined(_WIN32)
7474
assert(0 && "Not implemented on Windows");
7575
#else
7676
if (dylibHandle == nullptr) {
@@ -121,7 +121,7 @@ class CublasLtInstance {
121121
}
122122

123123
void unloadCublasDylib() {
124-
#ifdef WIN32
124+
#if defined(_WIN32)
125125
assert(0 && "Not implemented on Windows");
126126
#else
127127
dlclose(dylibHandle);

third_party/proton/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ add_library(proton SHARED ${_proton_obj_sources})
8383

8484
target_link_libraries(proton PRIVATE Python3::Module)
8585

86+
if(WIN32)
87+
set_target_properties(proton PROPERTIES SUFFIX ".pyd")
88+
set_target_properties(proton PROPERTIES PREFIX "lib")
89+
if(DEFINED TRITON_PYD_PATH)
90+
string(TOUPPER "${CMAKE_BUILD_TYPE}" UPPER_CMAKE_BUILD_TYPE)
91+
set_target_properties(proton PROPERTIES
92+
RUNTIME_OUTPUT_DIRECTORY_${UPPER_CMAKE_BUILD_TYPE}
93+
"${TRITON_PYD_PATH}")
94+
endif()
95+
endif()
96+
8697
# Apply any macOS linker flags or extra link options
8798
if(PROTON_PYTHON_LDFLAGS)
8899
target_link_options(proton PRIVATE ${PROTON_PYTHON_LDFLAGS})

0 commit comments

Comments
 (0)