Skip to content

Commit ccd2874

Browse files
authored
Feature: Allow directly compiling CUDA version on DCU harware (#5727)
* Initial commit * Modify CMakeLists
1 parent 4891b2e commit ccd2874

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ option(ENABLE_CNPY "Enable cnpy usage." OFF)
4040
option(ENABLE_PEXSI "Enable support for PEXSI." OFF)
4141
option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF)
4242
option(USE_DSP "Enable DSP usage." OFF)
43+
option(USE_CUDA_ON_DCU "Enable CUDA on DCU" OFF)
4344

4445
# enable json support
4546
if(ENABLE_RAPIDJSON)
@@ -126,6 +127,10 @@ if (USE_DSP)
126127
set(ABACUS_BIN_NAME abacus_dsp)
127128
endif()
128129

130+
if (USE_CUDA_ON_DCU)
131+
add_compile_definitions(__CUDA_ON_DCU)
132+
endif()
133+
129134
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
130135

131136
if(ENABLE_COVERAGE)

source/module_base/module_device/device.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ void record_device_memory(const Device* dev, std::ofstream& ofs_device, std::str
8686
* @brief for compatibility with __CUDA_ARCH__ 600 and earlier
8787
*
8888
*/
89-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
89+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 && !defined(__CUDA_ON_DCU)
9090
static __inline__ __device__ double atomicAdd(double* address, double val)
9191
{
9292
unsigned long long int* address_as_ull = (unsigned long long int*)address;

source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ void cal_force_npw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
893893
int t_num = (npw%t_size) ? (npw/t_size + 1) : (npw/t_size);
894894
dim3 npwgrid(((t_num%THREADS_PER_BLOCK) ? (t_num/THREADS_PER_BLOCK + 1) : (t_num/THREADS_PER_BLOCK)));
895895

896-
cal_force_npw << < npwgrid, THREADS_PER_BLOCK >> > (
896+
cal_force_npw <<< npwgrid, THREADS_PER_BLOCK >>> (
897897
reinterpret_cast<const thrust::complex<FPTYPE>*>(psiv),
898898
gv_x, gv_y, gv_z, rhocgigg_vec, force, pos_x, pos_y, pos_z,
899899
npw, omega, tpiba

0 commit comments

Comments
 (0)