
Commit 0f07afa

merge cuda and rocm files (#2844)
Merge `source/lib/src/cuda` and `source/lib/src/rocm` into `source/lib/src/gpu`.

- Define macros `gpuGetLastError`, `gpuDeviceSynchronize`, `gpuMemcpy`, `gpuMemcpyDeviceToHost`, `gpuMemcpyHostToDevice`, and `gpuMemset` to make them available for both CUDA and ROCm.
- Use the `<<< >>>` syntax for both CUDA and ROCm. Per ROCm/hip@cf78d85, it has been supported in HIP since 2018.
- Fix several `int` constants that should be `double` or `float`.
- For tabulate:
  - Fix `WARP_SIZE` for ROCm. Per pytorch/pytorch#64302, `WARP_SIZE` can be 32 or 64, so it should not be hardcoded to 64.
  - Add `GpuShuffleSync`. Per ROCm/hip#1491, `__shfl_sync` is not supported by HIP.
  - After merging the code, #1274 should also work for ROCm.
  - Use the same `ii` for #830 and #2357. Although both work, `ii` had different meanings in these two PRs; now it is the same.
  - However, `ii` in `tabulate_fusion_se_a_fifth_order_polynomial` (ROCm) added by #2532 was wrong; merging the code corrects it.
  - The optimization in #830 had not been applied to ROCm.
  - `__syncwarp` is not supported by ROCm.
- After merging the code, #2661 will also apply to ROCm. Although the TF ROCm stream is still blocking (https://github.com/tensorflow/tensorflow/blob/9d1262082e761cd85d6726bcbdfdef331d6d72c6/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver.cc#L566), we don't know whether it will change to non-blocking.
- There are several other differences between CUDA and ROCm.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 544875e commit 0f07afa
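
The `GpuShuffleSync` bullet above is the subtlest of these changes; a minimal sketch of how such a wrapper could dispatch per backend follows. Only the name comes from this commit — the guard, signature, and bodies here are assumptions (HIP's `__shfl` takes no mask argument, which is why a plain alias to `__shfl_sync` cannot work).

```cpp
// Hedged sketch, not the repository's exact definition. On CUDA 9+, warp
// shuffles must carry a participation mask (__shfl_sync); HIP only offers
// the unmasked __shfl (ROCm/hip#1491), so the mask is dropped there.
#ifdef __HIPCC__
template <typename T>
__device__ __forceinline__ T GpuShuffleSync(unsigned mask, T val, int src) {
  (void)mask;               // HIP has no sync mask; ignore it
  return __shfl(val, src);  // unmasked shuffle across the wavefront
}
#else
template <typename T>
__device__ __forceinline__ T GpuShuffleSync(unsigned mask, T val, int src) {
  return __shfl_sync(mask, val, src);  // masked shuffle, CUDA 9 and later
}
#endif
```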


41 files changed: +490 additions, −3878 deletions

.github/labeler.yml

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@ Python:
 Docs: doc/**/*
 Examples: examples/**/*
 Core: source/lib/**/*
-CUDA: source/lib/src/cuda/**/*
-ROCM: source/lib/src/rocm/**/*
+CUDA: source/lib/src/gpu/**/*
+ROCM: source/lib/src/gpu/**/*
 OP: source/op/**/*
 C++: source/api_cc/**/*
 C: source/api_c/**/*

.gitmodules

Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
-[submodule "source/lib/src/cuda/cub"]
-    path = source/lib/src/cuda/cub
+[submodule "source/lib/src/gpu/cub"]
+    path = source/lib/src/gpu/cub
     url = https://github.com/NVIDIA/cub.git

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ repos:
     rev: v16.0.6
     hooks:
       - id: clang-format
-        exclude: ^source/3rdparty|source/lib/src/cuda/cudart/.+\.inc
+        exclude: ^source/3rdparty|source/lib/src/gpu/cudart/.+\.inc
   # CSS
   - repo: https://github.com/pre-commit/mirrors-csslint
     rev: v1.0.5
@@ -83,7 +83,7 @@ repos:
           - --comment-style
          - //
           - --no-extra-eol
-        exclude: ^source/3rdparty|source/lib/src/cuda/cudart/.+\.inc
+        exclude: ^source/3rdparty|source/lib/src/gpu/cudart/.+\.inc
   # CSS
       - id: insert-license
         files: \.(css|scss)$

doc/install/install-from-source.md

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ One may set the following environment variables before executing `pip`:
 | Environment variables | Allowed value | Default value | Usage |
 | --------------------- | ---------------------- | ------------- | -------------------------- |
 | DP_VARIANT | `cpu`, `cuda`, `rocm` | `cpu` | Build CPU variant or GPU variant with CUDA or ROCM support. |
-| CUDAToolkit_ROOT | Path | Detected automatically | The path to the CUDA toolkit directory. CUDA 7.0 or later is supported. NVCC is required. |
+| CUDAToolkit_ROOT | Path | Detected automatically | The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is required. |
 | ROCM_ROOT | Path | Detected automatically | The path to the ROCM toolkit directory. |
 | TENSORFLOW_ROOT | Path | Detected automatically | The path to TensorFlow Python library. By default the installer only finds TensorFlow under user site-package directory (`site.getusersitepackages()`) or system site-package directory (`sysconfig.get_path("purelib")`) due to limitation of [PEP-517](https://peps.python.org/pep-0517/). If not found, the latest TensorFlow (or the environment variable `TENSORFLOW_VERSION` if given) from PyPI will be built against.|
 | DP_ENABLE_NATIVE_OPTIMIZATION | 0, 1 | 0 | Enable compilation optimization for the native machine's CPU type. Do not enable it if generated code will run on different CPUs. |
@@ -188,7 +188,7 @@ One may add the following arguments to `cmake`:
 | -DTENSORFLOW_ROOT=<value> | Path | - | The Path to TensorFlow's C++ interface. |
 | -DCMAKE_INSTALL_PREFIX=<value> | Path | - | The Path where DeePMD-kit will be installed. |
 | -DUSE_CUDA_TOOLKIT=<value> | `TRUE` or `FALSE` | `FALSE` | If `TRUE`, Build GPU support with CUDA toolkit. |
-| -DCUDAToolkit_ROOT=<value> | Path | Detected automatically | The path to the CUDA toolkit directory. CUDA 7.0 or later is supported. NVCC is required. |
+| -DCUDAToolkit_ROOT=<value> | Path | Detected automatically | The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is required. |
 | -DUSE_ROCM_TOOLKIT=<value> | `TRUE` or `FALSE` | `FALSE` | If `TRUE`, Build GPU support with ROCM toolkit. |
 | -DCMAKE_HIP_COMPILER_ROCM_ROOT=<value> | Path | Detected automatically | The path to the ROCM toolkit directory. |
 | -DLAMMPS_SOURCE_ROOT=<value> | Path | - | Only necessary for LAMMPS plugin mode. The path to the [LAMMPS source code](install-lammps.md). LAMMPS 8Apr2021 or later is supported. If not assigned, the plugin mode will not be enabled. |

source/lib/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ target_include_directories(
 
 if(USE_CUDA_TOOLKIT)
   add_definitions("-DGOOGLE_CUDA")
-  add_subdirectory(src/cuda)
+  add_subdirectory(src/gpu)
   set(EXTRA_LIBS ${EXTRA_LIBS} deepmd_op_cuda)
   target_link_libraries(${libname} INTERFACE deepmd_dyn_cudart ${EXTRA_LIBS})
   # gpu_cuda.h
@@ -22,7 +22,7 @@ endif()
 
 if(USE_ROCM_TOOLKIT)
   add_definitions("-DTENSORFLOW_USE_ROCM")
-  add_subdirectory(src/rocm)
+  add_subdirectory(src/gpu)
   set(EXTRA_LIBS ${EXTRA_LIBS} deepmd_op_rocm)
   target_link_libraries(${libname} INTERFACE ${ROCM_LIBRARIES} ${EXTRA_LIBS})
   # gpu_rocm.h

source/lib/include/gpu_cuda.h

Lines changed: 7 additions & 0 deletions
@@ -8,6 +8,13 @@
 
 #include "errors.h"
 
+#define gpuGetLastError cudaGetLastError
+#define gpuDeviceSynchronize cudaDeviceSynchronize
+#define gpuMemcpy cudaMemcpy
+#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
+#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
+#define gpuMemset cudaMemset
+
 #define GPU_MAX_NBOR_SIZE 4096
 #define DPErrcheck(res) \
   { DPAssert((res), __FILE__, __LINE__); }

source/lib/include/gpu_rocm.h

Lines changed: 7 additions & 0 deletions
@@ -11,6 +11,13 @@
 
 #define GPU_MAX_NBOR_SIZE 4096
 
+#define gpuGetLastError hipGetLastError
+#define gpuDeviceSynchronize hipDeviceSynchronize
+#define gpuMemcpy hipMemcpy
+#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
+#define gpuMemset hipMemset
+
 #define DPErrcheck(res) \
   { DPAssert((res), __FILE__, __LINE__); }
 inline void DPAssert(hipError_t code,
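
With the aliases from these two headers, shared GPU sources spell every runtime call one way, as the coord.cu diff below demonstrates. A self-contained sketch of the resulting pattern (the helper and its names are illustrative, not from the repository):

```cpp
#include "gpu_cuda.h"  // or "gpu_rocm.h" in a ROCm build

// Illustrative helper: copy n floats back from the device using only the
// gpu* aliases, so the same source compiles against CUDA and HIP.
void copy_back(float *host_buf, const float *dev_buf, int n) {
  DPErrcheck(gpuMemcpy(host_buf, dev_buf, sizeof(float) * n,
                       gpuMemcpyDeviceToHost));
  DPErrcheck(gpuGetLastError());       // surface any pending async error
  DPErrcheck(gpuDeviceSynchronize());  // then wait for the device
}
```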

source/lib/src/cuda/CMakeLists.txt

Lines changed: 0 additions & 60 deletions
This file was deleted.

source/lib/src/gpu/CMakeLists.txt

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+if(USE_CUDA_TOOLKIT)
+  # required cmake version 3.23: CMAKE_CUDA_ARCHITECTURES all
+  cmake_minimum_required(VERSION 3.23)
+  # project name
+  project(deepmd_op_cuda)
+  set(GPU_LIB_NAME deepmd_op_cuda)
+
+  set(CMAKE_CUDA_ARCHITECTURES all)
+  enable_language(CUDA)
+  set(CMAKE_CUDA_STANDARD 11)
+  add_compile_definitions(
+    "$<$<COMPILE_LANGUAGE:CUDA>:_GLIBCXX_USE_CXX11_ABI=${OP_CXX_ABI}>")
+
+  find_package(CUDAToolkit REQUIRED)
+
+  # take dynamic open cudart library replace of static one so it's not required
+  # when using CPUs
+  add_subdirectory(cudart)
+
+  # nvcc -o libdeepmd_op_cuda.so -I/usr/local/cub-1.8.0 -rdc=true
+  # -DHIGH_PREC=true -gencode arch=compute_61,code=sm_61 -shared -Xcompiler
+  # -fPIC deepmd_op.cu -L/usr/local/cuda/lib64 -lcudadevrt very important here!
+  # Include path to cub. for searching device compute capability,
+  # https://developer.nvidia.com/cuda-gpus
+
+  # cub has been included in CUDA Toolkit 11, we do not need to include it any
+  # more see https://github.com/NVIDIA/cub
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS "11")
+    include_directories(cub)
+  endif()
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS "9")
+    message(FATAL_ERROR "CUDA version must be >= 9.0")
+  endif()
+
+  message(STATUS "NVCC version is " ${CMAKE_CUDA_COMPILER_VERSION})
+
+  # arch will be configured by CMAKE_CUDA_ARCHITECTURES
+  set(CMAKE_CUDA_FLAGS
+      "${CMAKE_CUDA_FLAGS} -DCUB_IGNORE_DEPRECATED_CPP_DIALECT -DCUB_IGNORE_DEPRECATED_CPP_DIALECT"
+  )
+
+  file(GLOB SOURCE_FILES "*.cu")
+
+  add_library(${GPU_LIB_NAME} SHARED ${SOURCE_FILES})
+  target_link_libraries(${GPU_LIB_NAME} PRIVATE deepmd_dyn_cudart)
+
+elseif(USE_ROCM_TOOLKIT)
+
+  # required cmake version
+  cmake_minimum_required(VERSION 3.21)
+  # project name
+  project(deepmd_op_rocm)
+  set(GPU_LIB_NAME deepmd_op_rocm)
+  set(CMAKE_LINK_WHAT_YOU_USE TRUE)
+
+  # set c++ version c++11
+  set(CMAKE_CXX_STANDARD 14)
+  set(CMAKE_HIP_STANDARD 14)
+  add_definitions("-DCUB_IGNORE_DEPRECATED_CPP_DIALECT")
+  add_definitions("-DCUB_IGNORE_DEPRECATED_CPP_DIALECT")
+
+  message(STATUS "HIP major version is " ${HIP_VERSION_MAJOR})
+
+  set(HIP_HIPCC_FLAGS -fno-gpu-rdc; -fPIC --std=c++14 ${HIP_HIPCC_FLAGS}
+  )# --amdgpu-target=gfx906
+  if(HIP_VERSION VERSION_LESS 3.5.1)
+    set(HIP_HIPCC_FLAGS -hc; ${HIP_HIPCC_FLAGS})
+  endif()
+
+  file(GLOB SOURCE_FILES "*.cu")
+
+  hip_add_library(${GPU_LIB_NAME} SHARED ${SOURCE_FILES})
+
+endif()
+
+target_include_directories(
+  ${GPU_LIB_NAME}
+  PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../include/>
+         $<INSTALL_INTERFACE:include>)
+target_precompile_headers(${GPU_LIB_NAME} PUBLIC [["device.h"]])
+if(APPLE)
+  set_target_properties(${GPU_LIB_NAME} PROPERTIES INSTALL_RPATH @loader_path)
+else()
+  set_target_properties(${GPU_LIB_NAME} PROPERTIES INSTALL_RPATH "$ORIGIN")
+endif()
+
+if(BUILD_CPP_IF AND NOT BUILD_PY_IF)
+  install(
+    TARGETS ${GPU_LIB_NAME}
+    EXPORT ${CMAKE_PROJECT_NAME}Targets
+    DESTINATION lib/)
+endif(BUILD_CPP_IF AND NOT BUILD_PY_IF)
+if(BUILD_PY_IF)
+  install(TARGETS ${GPU_LIB_NAME} DESTINATION deepmd/lib/)
+endif(BUILD_PY_IF)
source/lib/src/cuda/coord.cu renamed to source/lib/src/gpu/coord.cu

Lines changed: 26 additions & 24 deletions
@@ -266,21 +266,21 @@ void compute_int_data(int *int_data,
   _fill_idx_cellmap<<<nblock_loc, TPB>>>(idx_cellmap, idx_cellmap_noshift, in_c,
                                          rec_boxt, nat_stt, nat_end, ext_stt,
                                          ext_end, nloc);
-  DPErrcheck(cudaGetLastError());
-  DPErrcheck(cudaDeviceSynchronize());
+  DPErrcheck(gpuGetLastError());
+  DPErrcheck(gpuDeviceSynchronize());
 
   const int nblock_loc_cellnum = (loc_cellnum + TPB - 1) / TPB;
   _fill_loc_cellnum_map<<<nblock_loc_cellnum, TPB>>>(
       temp_idx_order, loc_cellnum_map, idx_cellmap_noshift, nloc, loc_cellnum);
-  DPErrcheck(cudaGetLastError());
-  DPErrcheck(cudaDeviceSynchronize());
+  DPErrcheck(gpuGetLastError());
+  DPErrcheck(gpuDeviceSynchronize());
 
   const int nblock_total_cellnum = (total_cellnum + TPB - 1) / TPB;
   _fill_total_cellnum_map<<<nblock_total_cellnum, TPB>>>(
       total_cellnum_map, mask_cellnum_map, cell_map, cell_shift_map, nat_stt,
       nat_end, ext_stt, ext_end, loc_cellnum_map, total_cellnum);
-  DPErrcheck(cudaGetLastError());
-  DPErrcheck(cudaDeviceSynchronize());
+  DPErrcheck(gpuGetLastError());
+  DPErrcheck(gpuDeviceSynchronize());
 }
 
 void build_loc_clist(int *int_data,
@@ -297,8 +297,8 @@ void build_loc_clist(int *int_data,
       total_cellnum * 3 + loc_cellnum + 1 + total_cellnum + 1;
   _build_loc_clist<<<nblock, TPB>>>(loc_clist, idx_cellmap_noshift,
                                     temp_idx_order, sec_loc_cellnum_map, nloc);
-  DPErrcheck(cudaGetLastError());
-  DPErrcheck(cudaDeviceSynchronize());
+  DPErrcheck(gpuGetLastError());
+  DPErrcheck(gpuDeviceSynchronize());
 }
 
 template <typename FPTYPE>
@@ -326,23 +326,23 @@ void copy_coord(FPTYPE *out_c,
                                        cell_shift_map, sec_loc_cellnum_map,
                                        sec_total_cellnum_map, loc_clist, nloc, nall,
                                        total_cellnum, boxt, rec_boxt);
-  DPErrcheck(cudaGetLastError());
-  DPErrcheck(cudaDeviceSynchronize());
+  DPErrcheck(gpuGetLastError());
+  DPErrcheck(gpuDeviceSynchronize());
 }
 
 namespace deepmd {
 template <typename FPTYPE>
 void normalize_coord_gpu(FPTYPE *coord,
                          const int natom,
                          const Region<FPTYPE> &region) {
-  DPErrcheck(cudaGetLastError());
-  DPErrcheck(cudaDeviceSynchronize());
+  DPErrcheck(gpuGetLastError());
+  DPErrcheck(gpuDeviceSynchronize());
   const FPTYPE *boxt = region.boxt;
   const FPTYPE *rec_boxt = region.rec_boxt;
   const int nblock = (natom + TPB - 1) / TPB;
   normalize_one<<<nblock, TPB>>>(coord, boxt, rec_boxt, natom);
-  DPErrcheck(cudaGetLastError());
-  DPErrcheck(cudaDeviceSynchronize());
+  DPErrcheck(gpuGetLastError());
+  DPErrcheck(gpuDeviceSynchronize());
 }
 
 // int_data(temp cuda
@@ -362,16 +362,17 @@ int copy_coord_gpu(FPTYPE *out_c,
                    const int &total_cellnum,
                    const int *cell_info,
                    const Region<FPTYPE> &region) {
-  DPErrcheck(cudaGetLastError());
-  DPErrcheck(cudaDeviceSynchronize());
+  DPErrcheck(gpuGetLastError());
+  DPErrcheck(gpuDeviceSynchronize());
   compute_int_data(int_data, in_c, cell_info, region, nloc, loc_cellnum,
                    total_cellnum);
   int *int_data_cpu = new int
       [loc_cellnum + 2 * total_cellnum + loc_cellnum + 1 + total_cellnum +
       1];  // loc_cellnum_map,total_cellnum_map,mask_cellnum_map,sec_loc_cellnum_map,sec_total_cellnum_map
-  DPErrcheck(cudaMemcpy(int_data_cpu, int_data + 3 * nloc,
-                        sizeof(int) * (loc_cellnum + 2 * total_cellnum),
-                        cudaMemcpyDeviceToHost));
+  DPErrcheck(gpuMemcpy(int_data_cpu, int_data + 3 * nloc,
+                       sizeof(int) * (loc_cellnum + 2 * total_cellnum),
+                       gpuMemcpyDeviceToHost));
+  DPErrcheck(gpuGetLastError());
   int *loc_cellnum_map = int_data_cpu;
   int *total_cellnum_map = loc_cellnum_map + loc_cellnum;
   int *mask_cellnum_map = total_cellnum_map + total_cellnum;
@@ -397,11 +398,12 @@ int copy_coord_gpu(FPTYPE *out_c,
     // size of the output arrays is not large enough
     return 1;
   } else {
-    DPErrcheck(cudaMemcpy(int_data + nloc * 3 + loc_cellnum +
-                              total_cellnum * 3 + total_cellnum * 3,
-                          sec_loc_cellnum_map,
-                          sizeof(int) * (loc_cellnum + 1 + total_cellnum + 1),
-                          cudaMemcpyHostToDevice));
+    DPErrcheck(gpuMemcpy(int_data + nloc * 3 + loc_cellnum + total_cellnum * 3 +
+                             total_cellnum * 3,
+                         sec_loc_cellnum_map,
+                         sizeof(int) * (loc_cellnum + 1 + total_cellnum + 1),
+                         gpuMemcpyHostToDevice));
+    DPErrcheck(gpuGetLastError());
     delete[] int_data_cpu;
     build_loc_clist(int_data, nloc, loc_cellnum, total_cellnum);
     copy_coord(out_c, out_t, mapping, int_data, in_c, in_t, nloc, *nall,