@@ -7,6 +7,8 @@ namespace hsolver
77{
88const int warp_size = 32 ;
99const int thread_per_block = 256 ;
10+ #define FULL_MASK 0xffffffff
11+ #define WARP_SIZE 32
1012
1113template <typename Real>
1214__global__ void line_minimize_with_block (
@@ -282,6 +284,37 @@ __global__ void precondition_kernel(
282284 }
283285}
284286
// Sum-reduce `val` across the 32 lanes of a warp using register shuffles.
// Uses the full-warp mask, so all 32 lanes must reach this call together.
// After the loop, lane 0 holds the sum of all lanes' inputs; other lanes
// hold partial sums, so callers should read the result from lane 0.
template <typename Real>
__device__ Real warpReduceSum(Real val)
{
    int offset = WARP_SIZE >> 1;
    while (offset > 0) {
        val += __shfl_down_sync(FULL_MASK, val, offset);
        offset >>= 1;
    }
    return val;
}
293+
// Sum-reduce `val` across all threads of the block.
//
// `shared` must point to at least ceil(blockDim.x / WARP_SIZE) Reals of
// shared memory (the launch site in this file provides
// thread_per_block / WARP_SIZE, which is sufficient because
// thread_per_block is a multiple of WARP_SIZE).
// Every thread of the block must call this function (it contains
// __syncthreads()), and every thread receives the full block sum.
// Assumes blockDim.x <= WARP_SIZE * WARP_SIZE so the partial sums fit
// in a single warp for the second stage.
template <typename Real>
__device__ Real blockReduceSum(Real val, volatile Real* shared)
{
    const int lane = threadIdx.x % WARP_SIZE; // lane index inside the warp
    const int wid = threadIdx.x / WARP_SIZE;  // warp index inside the block

    // Stage 1: each warp reduces its own values; lane 0 holds the warp sum.
    val = warpReduceSum(val);
    if (lane == 0)
        shared[wid] = val;

    __syncthreads();

    // Stage 2: the first warp reduces the per-warp partial sums.
    // Ceil division so a trailing partial warp is not dropped when
    // blockDim.x is not a multiple of WARP_SIZE (floor division here
    // would silently lose the last warp's contribution).
    const int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
    if (wid == 0) {
        Real sum = (lane < num_warps) ? shared[lane] : Real(0);
        sum = warpReduceSum(sum);
        if (lane == 0)
            shared[0] = sum;
    }

    __syncthreads();
    return shared[0];
}
316+
317+
285318template <typename Real>
286319__global__ void normalize_kernel (
287320 thrust::complex <Real>* psi_iter,
@@ -292,38 +325,19 @@ __global__ void normalize_kernel(
292325{
293326 int m = blockIdx .x ;
294327 int tid = threadIdx .x ;
295- __shared__ Real sum[thread_per_block];
328+ extern __shared__ char s_char[];
329+ Real* shared = reinterpret_cast <Real*>(s_char);
296330
297- sum[tid] = 0.0 ;
331+ Real local_sum = 0.0 ;
298332
299333 // Calculate the sum for normalization
300334 for (int i = tid; i < dim; i += thread_per_block) {
301335 auto val = psi_iter[(nbase + m) * dim + i];
302- sum[tid] += (val * thrust::conj (val)).real ();
336+ local_sum += (val * thrust::conj (val)).real ();
303337 }
304338
305- __syncthreads ();
306-
307- // Parallel reduction in shared memory
308- for (int s = thread_per_block/2 ; s > warp_size; s >>= 1 ) {
309- if (tid < s) {
310- sum[tid] += sum[tid + s];
311- }
312- __syncthreads ();
313- }
314-
315- if (tid < warp_size) {
316- sum[tid] += sum[tid + 32 ]; __syncwarp ();
317- sum[tid] += sum[tid + 16 ]; __syncwarp ();
318- sum[tid] += sum[tid + 8 ]; __syncwarp ();
319- sum[tid] += sum[tid + 4 ]; __syncwarp ();
320- sum[tid] += sum[tid + 2 ]; __syncwarp ();
321- sum[tid] += sum[tid + 1 ]; __syncwarp ();
322- }
323-
324- __syncthreads ();
325-
326- Real norm = sqrt (sum[0 ]);
339+ Real l2_sq = blockReduceSum (local_sum, shared);
340+ Real norm = sqrt (l2_sq);
327341
328342 // Normalize the vector
329343 for (int i = tid; i < dim; i += thread_per_block) {
@@ -452,8 +466,9 @@ void normalize_op<T, base_device::DEVICE_GPU>::operator()(const int& dim,
452466 Real* psi_norm)
453467{
454468 auto psi_complex = reinterpret_cast <thrust::complex <Real>*>(psi_iter);
469+ int sharedMemSize = (thread_per_block / WARP_SIZE) * sizeof (Real);
455470
456- normalize_kernel<Real><<<notconv, thread_per_block>>> (
471+ normalize_kernel<Real><<<notconv, thread_per_block, sharedMemSize, 0 >>> (
457472 psi_complex, psi_norm, dim, nbase, notconv);
458473
459474 cudaCheckOnDebug ();
0 commit comments