Skip to content

Commit e500584

Browse files
authored
Fix low efficiency of stress_kin in DCU (#5636)
* fix stuck in out_chg * fix DCU low efficiency
1 parent 7198ec1 commit e500584

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,25 @@ __global__ void cal_multi_dot(const int npw,
185185
const thrust::complex<FPTYPE>* psi,
186186
FPTYPE* sum)
187187
{
188-
int idx = threadIdx.x + blockIdx.x * blockDim.x;
189-
if (idx < npw) {
190-
atomicAdd(sum, fac * gk1[idx] * gk2[idx] * d_kfac[idx] * thrust::norm(psi[idx]));
188+
__shared__ FPTYPE s_sum[THREADS_PER_BLOCK];
189+
int tid = threadIdx.x + blockIdx.x * blockDim.x;
190+
int cacheid = threadIdx.x;
191+
FPTYPE local_sum = 0;
192+
while (tid < npw) {
193+
local_sum += fac * gk1[tid] * gk2[tid] * d_kfac[tid] * thrust::norm(psi[tid]);
194+
tid += blockDim.x * gridDim.x;
195+
}
196+
s_sum[cacheid] = local_sum;
197+
__syncthreads();
198+
199+
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
200+
if (cacheid < s) {
201+
s_sum[cacheid] += s_sum[cacheid + s];
202+
}
203+
__syncthreads();
204+
}
205+
if (cacheid == 0) {
206+
atomicAdd(sum, s_sum[0]);
191207
}
192208
}
193209

source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/stress_op.hip.cu

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,25 @@ __global__ void cal_multi_dot(const int npw,
342342
const thrust::complex<FPTYPE>* psi,
343343
FPTYPE* sum)
344344
{
345-
int idx = threadIdx.x + blockIdx.x * blockDim.x;
346-
if (idx < npw) {
347-
atomicAdd(sum, fac * gk1[idx] * gk2[idx] * d_kfac[idx] * thrust::norm(psi[idx]));
345+
__shared__ FPTYPE s_sum[THREADS_PER_BLOCK];
346+
int tid = threadIdx.x + blockIdx.x * blockDim.x;
347+
int cacheid = threadIdx.x;
348+
FPTYPE local_sum = 0;
349+
while (tid < npw) {
350+
local_sum += fac * gk1[tid] * gk2[tid] * d_kfac[tid] * thrust::norm(psi[tid]);
351+
tid += blockDim.x * gridDim.x;
352+
}
353+
s_sum[cacheid] = local_sum;
354+
__syncthreads();
355+
356+
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
357+
if (cacheid < s) {
358+
s_sum[cacheid] += s_sum[cacheid + s];
359+
}
360+
__syncthreads();
361+
}
362+
if (cacheid == 0) {
363+
atomicAdd(sum, s_sum[0]);
348364
}
349365
}
350366

0 commit comments

Comments
 (0)