Skip to content

Commit bccaa45

Browse files
authored
[ROCm][Windows] Enable build with ROCm on Windows (#3883)
Added changes for Windows in LoadHIP.cmake: - set default ROCM_PATH and HIP_PATH for Windows case - update CMAKE_MODULE_PATH with appropriate HIP path - because the find_package call to this module uses the Module mode search - added code for finding HIP version - skipped find hsa-runtime64, rccl packages on Windows case - skipped find rccl, roctx and roctracer libraries on Windows case CMAKE_C_COMPILER and CMAKE_CXX_COMPILER redefinition on Windows case (if we are building with ROCm on Windows we need to use compiler specified with CXX and CC flags instead of cl)
1 parent 318bace commit bccaa45

File tree

2 files changed

+77
-37
lines changed

2 files changed

+77
-37
lines changed

cmake/LoadHIP.cmake

Lines changed: 71 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
set(PYTORCH_FOUND_HIP FALSE)
22

33
if(NOT DEFINED ENV{ROCM_PATH})
4-
set(ROCM_PATH /opt/rocm)
4+
if(UNIX)
5+
set(ROCM_PATH /opt/rocm)
6+
else() # Win32
7+
set(ROCM_PATH C:/opt/rocm)
8+
endif()
59
else()
610
set(ROCM_PATH $ENV{ROCM_PATH})
711
endif()
812

913
# HIP_PATH
1014
if(NOT DEFINED ENV{HIP_PATH})
11-
set(HIP_PATH ${ROCM_PATH}/hip)
15+
if(UNIX)
16+
set(HIP_PATH ${ROCM_PATH}/hip)
17+
else() #Win32
18+
set(HIP_PATH ${ROCM_PATH})
19+
endif()
1220
else()
1321
set(HIP_PATH $ENV{HIP_PATH})
1422
endif()
@@ -129,7 +137,9 @@ else()
129137
endif()
130138

131139
# Add HIP to the CMAKE Module Path
132-
set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})
140+
# needed because the find_package call to this module uses the Module mode search
141+
# https://cmake.org/cmake/help/latest/command/find_package.html#search-modes
142+
set(CMAKE_MODULE_PATH ${HIP_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH})
133143

134144
# Disable Asserts In Code (Can't use asserts on HIP stack.)
135145
add_definitions(-DNDEBUG)
@@ -145,29 +155,49 @@ find_package_and_print_version(HIP 1.0)
145155
if(HIP_FOUND)
146156
set(PYTORCH_FOUND_HIP TRUE)
147157

148-
# Find ROCM version for checks
149-
file(READ "${ROCM_PATH}/.info/version-dev" ROCM_VERSION_DEV_RAW)
150-
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW})
151-
if(ROCM_VERSION_DEV_MATCH)
152-
set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
153-
set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
154-
set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
155-
set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
158+
if(UNIX)
159+
set(ROCM_LIB_NAME "ROCM")
160+
else() # Win32
161+
set(ROCM_LIB_NAME "HIP")
162+
endif()
163+
if(UNIX)
164+
# Find ROCM version for checks
165+
file(READ "${ROCM_PATH}/.info/version-dev" ${ROCM_LIB_NAME}_VERSION_DEV_RAW)
166+
else() #Win32
167+
# Find HIP version from hipconfig execution
168+
execute_process(
169+
COMMAND ${ROCM_PATH}/bin/hipconfig.bat --version
170+
OUTPUT_VARIABLE ${ROCM_LIB_NAME}_VERSION_DEV_RAW
171+
OUTPUT_STRIP_TRAILING_WHITESPACE
172+
)
173+
endif()
174+
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+).*$" ${ROCM_LIB_NAME}_VERSION_DEV_MATCH ${${ROCM_LIB_NAME}_VERSION_DEV_RAW})
175+
if(${ROCM_LIB_NAME}_VERSION_DEV_MATCH)
176+
set(${ROCM_LIB_NAME}_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
177+
set(${ROCM_LIB_NAME}_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
178+
set(${ROCM_LIB_NAME}_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
179+
set(${ROCM_LIB_NAME}_VERSION_DEV "${${ROCM_LIB_NAME}_VERSION_DEV_MAJOR}.${${ROCM_LIB_NAME}_VERSION_DEV_MINOR}.${${ROCM_LIB_NAME}_VERSION_DEV_PATCH}")
180+
endif()
181+
if(UNIX)
182+
message("\n***** ROCm version from ${ROCM_PATH}/.info/version-dev ****\n")
183+
else() #Win32
184+
message("\n***** HIP version from ${ROCM_PATH}/bin/hipconfig.bat --version ****\n")
185+
endif()
186+
message("${ROCM_LIB_NAME}_VERSION_DEV: ${${ROCM_LIB_NAME}_VERSION_DEV}")
187+
message("${ROCM_LIB_NAME}_VERSION_DEV_MAJOR: ${${ROCM_LIB_NAME}_VERSION_DEV_MAJOR}")
188+
message("${ROCM_LIB_NAME}_VERSION_DEV_MINOR: ${${ROCM_LIB_NAME}_VERSION_DEV_MINOR}")
189+
message("${ROCM_LIB_NAME}_VERSION_DEV_PATCH: ${${ROCM_LIB_NAME}_VERSION_DEV_PATCH}")
190+
191+
if(UNIX)
192+
message("\n***** Library versions from dpkg *****\n")
193+
execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
194+
execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}")
195+
execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}")
196+
execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
197+
execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
198+
execute_process(COMMAND dpkg -l COMMAND grep hip_base COMMAND awk "{print $2 \" VERSION: \" $3}")
199+
execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
156200
endif()
157-
message("\n***** ROCm version from ${ROCM_PATH}/.info/version-dev ****\n")
158-
message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
159-
message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
160-
message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
161-
message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")
162-
163-
message("\n***** Library versions from dpkg *****\n")
164-
execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
165-
execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}")
166-
execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}")
167-
execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
168-
execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
169-
execute_process(COMMAND dpkg -l COMMAND grep hip_base COMMAND awk "{print $2 \" VERSION: \" $3}")
170-
execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
171201

172202
message("\n***** Library versions from cmake find_package *****\n")
173203

@@ -176,7 +206,6 @@ if(HIP_FOUND)
176206
### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###
177207

178208
set(hip_DIR ${HIP_PATH}/lib/cmake/hip)
179-
set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64)
180209
set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs)
181210
set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr)
182211
set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand)
@@ -186,13 +215,11 @@ if(HIP_FOUND)
186215
set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft)
187216
set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft)
188217
set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse)
189-
set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl)
190218
set(rocprim_DIR ${ROCPRIM_PATH}/lib/cmake/rocprim)
191219
set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub)
192220
set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust)
193221

194222
find_package_and_print_version(hip REQUIRED)
195-
find_package_and_print_version(hsa-runtime64 REQUIRED)
196223
find_package_and_print_version(amd_comgr REQUIRED)
197224
find_package_and_print_version(rocrand REQUIRED)
198225
find_package_and_print_version(hiprand REQUIRED)
@@ -203,7 +230,6 @@ if(HIP_FOUND)
203230
find_package_and_print_version(hipfft REQUIRED)
204231
endif()
205232
find_package_and_print_version(hipsparse REQUIRED)
206-
find_package_and_print_version(rccl)
207233
find_package_and_print_version(rocprim REQUIRED)
208234
find_package_and_print_version(hipcub REQUIRED)
209235
find_package_and_print_version(rocthrust REQUIRED)
@@ -223,12 +249,22 @@ if(HIP_FOUND)
223249
# TODO: miopen_LIBRARIES should return fullpath to the library file,
224250
# however currently it's just the lib name
225251
find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib)
226-
# TODO: rccl_LIBRARIES should return fullpath to the library file,
227-
# however currently it's just the lib name
228-
find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib)
229252
# hiprtc is part of HIP
230253
find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib)
231-
# roctx is part of roctracer
232-
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib)
233-
set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include)
254+
255+
if(UNIX)
256+
set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64)
257+
set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl)
258+
259+
find_package_and_print_version(hsa-runtime64 REQUIRED)
260+
find_package_and_print_version(rccl)
261+
262+
# TODO: rccl_LIBRARIES should return fullpath to the library file,
263+
# however currently it's just the lib name
264+
find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib)
265+
# roctx is part of roctracer
266+
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib)
267+
set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include)
268+
endif()
234269
endif()
270+

tools/setup_helpers/extension.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,13 @@ def build_extension(self, ext):
161161
import sys
162162

163163
python_version = sys.version_info
164+
165+
cxx_compiler = os.environ.get('CXX', 'cl')
166+
c_compiler = os.environ.get('CC', 'cl')
167+
164168
cmake_args += [
165-
"-DCMAKE_C_COMPILER=cl",
166-
"-DCMAKE_CXX_COMPILER=cl",
169+
f"-DCMAKE_C_COMPILER={c_compiler}",
170+
f"-DCMAKE_CXX_COMPILER={cxx_compiler}",
167171
f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}",
168172
]
169173

0 commit comments

Comments
 (0)