2 files changed: +38 / -6 lines changed
source/module_hamilt_pw/hamilt_pwdft/kernels Expand file tree Collapse file tree 2 files changed +38
-6
lines changed Original file line number Diff line number Diff line change @@ -185,9 +185,25 @@ __global__ void cal_multi_dot(const int npw,
185185 const thrust::complex <FPTYPE>* psi,
186186 FPTYPE* sum)
187187{
188- int idx = threadIdx .x + blockIdx .x * blockDim .x ;
189- if (idx < npw) {
190- atomicAdd (sum, fac * gk1[idx] * gk2[idx] * d_kfac[idx] * thrust::norm (psi[idx]));
188+ __shared__ FPTYPE s_sum[THREADS_PER_BLOCK];
189+ int tid = threadIdx .x + blockIdx .x * blockDim .x ;
190+ int cacheid = threadIdx .x ;
191+ FPTYPE local_sum = 0 ;
192+ while (tid < npw) {
193+ local_sum += fac * gk1[tid] * gk2[tid] * d_kfac[tid] * thrust::norm (psi[tid]);
194+ tid += blockDim .x * gridDim .x ;
195+ }
196+ s_sum[cacheid] = local_sum;
197+ __syncthreads ();
198+
199+ for (int s = blockDim .x / 2 ; s > 0 ; s >>= 1 ) {
200+ if (cacheid < s) {
201+ s_sum[cacheid] += s_sum[cacheid + s];
202+ }
203+ __syncthreads ();
204+ }
205+ if (cacheid == 0 ) {
206+ atomicAdd (sum, s_sum[0 ]);
191207 }
192208}
193209
Original file line number Diff line number Diff line change @@ -342,9 +342,25 @@ __global__ void cal_multi_dot(const int npw,
342342 const thrust::complex <FPTYPE>* psi,
343343 FPTYPE* sum)
344344{
345- int idx = threadIdx .x + blockIdx .x * blockDim .x ;
346- if (idx < npw) {
347- atomicAdd (sum, fac * gk1[idx] * gk2[idx] * d_kfac[idx] * thrust::norm (psi[idx]));
345+ __shared__ FPTYPE s_sum[THREADS_PER_BLOCK];
346+ int tid = threadIdx .x + blockIdx .x * blockDim .x ;
347+ int cacheid = threadIdx .x ;
348+ FPTYPE local_sum = 0 ;
349+ while (tid < npw) {
350+ local_sum += fac * gk1[tid] * gk2[tid] * d_kfac[tid] * thrust::norm (psi[tid]);
351+ tid += blockDim .x * gridDim .x ;
352+ }
353+ s_sum[cacheid] = local_sum;
354+ __syncthreads ();
355+
356+ for (int s = blockDim .x / 2 ; s > 0 ; s >>= 1 ) {
357+ if (cacheid < s) {
358+ s_sum[cacheid] += s_sum[cacheid + s];
359+ }
360+ __syncthreads ();
361+ }
362+ if (cacheid == 0 ) {
363+ atomicAdd (sum, s_sum[0 ]);
348364 }
349365}
350366
You can’t perform that action at this time.
0 commit comments