diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu b/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu index 36e0aac37a..997827d669 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu @@ -185,9 +185,25 @@ __global__ void cal_multi_dot(const int npw, const thrust::complex* psi, FPTYPE* sum) { - int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < npw) { - atomicAdd(sum, fac * gk1[idx] * gk2[idx] * d_kfac[idx] * thrust::norm(psi[idx])); + __shared__ FPTYPE s_sum[THREADS_PER_BLOCK]; + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int cacheid = threadIdx.x; + FPTYPE local_sum = 0; + while (tid < npw) { + local_sum += fac * gk1[tid] * gk2[tid] * d_kfac[tid] * thrust::norm(psi[tid]); + tid += blockDim.x * gridDim.x; + } + s_sum[cacheid] = local_sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (cacheid < s) { + s_sum[cacheid] += s_sum[cacheid + s]; + } + __syncthreads(); + } + if (cacheid == 0) { + atomicAdd(sum, s_sum[0]); } } diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/stress_op.hip.cu b/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/stress_op.hip.cu index cd26e312a5..a5f8e553af 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/stress_op.hip.cu +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/stress_op.hip.cu @@ -342,9 +342,25 @@ __global__ void cal_multi_dot(const int npw, const thrust::complex* psi, FPTYPE* sum) { - int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < npw) { - atomicAdd(sum, fac * gk1[idx] * gk2[idx] * d_kfac[idx] * thrust::norm(psi[idx])); + __shared__ FPTYPE s_sum[THREADS_PER_BLOCK]; + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int cacheid = threadIdx.x; + FPTYPE local_sum = 0; + while (tid < npw) { + local_sum += fac * gk1[tid] * gk2[tid] * d_kfac[tid] * thrust::norm(psi[tid]); + tid += blockDim.x * gridDim.x; + } + s_sum[cacheid] = local_sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (cacheid < s) { + s_sum[cacheid] += s_sum[cacheid + s]; + } + __syncthreads(); + } + if (cacheid == 0) { + atomicAdd(sum, s_sum[0]); } }