Commit 7945a36

Perf: use warp reduce instead of shared memory for better efficiency
Signed-off-by: Tianxiang Wang <[email protected]>, contributed under MetaX Integrated Circuits (Shanghai) Co., Ltd.
1 parent 968e537 commit 7945a36
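
The change replaces the fixed-size shared-memory tree reduction in normalize_kernel with a warp-shuffle reduction: each warp first sums its values entirely in registers via __shfl_down_sync, and only the per-warp partial sums pass through shared memory. As background, here is a minimal, self-contained sketch of that warp-shuffle pattern (illustrative only, not part of the commit; the kernel and variable names below are hypothetical):

```cuda
// Illustrative sketch of a warp-shuffle sum (not part of the commit).
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff
#define WARP_SIZE 32

// Each lane contributes one value; after the loop, lane 0 holds the warp's sum.
__device__ float warp_sum(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1)
        val += __shfl_down_sync(FULL_MASK, val, offset);
    return val;
}

__global__ void demo(float* out) {
    float v = 1.0f;            // every lane contributes 1
    float s = warp_sum(v);     // lane 0 now holds 32
    if (threadIdx.x == 0) *out = s;
}

int main() {
    float* d_out;
    cudaMalloc(&d_out, sizeof(float));
    demo<<<1, WARP_SIZE>>>(d_out);
    float h = 0.0f;
    cudaMemcpy(&h, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %g (expect 32)\n", h);
    cudaFree(d_out);
    return 0;
}
```

Shuffle intrinsics exchange registers directly between the lanes of a warp, so the intra-warp phase needs no shared-memory traffic and no __syncthreads(), which is the efficiency gain the commit message refers to.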

File tree

1 file changed: +41 −26 lines changed

source/source_hsolver/kernels/cuda/bpcg_kernel_op.cu

Lines changed: 41 additions & 26 deletions
```diff
@@ -7,6 +7,8 @@ namespace hsolver
 {
 const int warp_size = 32;
 const int thread_per_block = 256;
+#define FULL_MASK 0xffffffff
+#define WARP_SIZE 32
 
 template <typename Real>
 __global__ void line_minimize_with_block(
@@ -282,6 +284,37 @@ __global__ void precondition_kernel(
 }
 }
 
+template <typename Real>
+__device__ Real warpReduceSum(Real val) {
+    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1)
+        val += __shfl_down_sync(FULL_MASK, val, offset);
+    return val;
+}
+
+template <typename Real>
+__device__ Real blockReduceSum(Real val, volatile Real* shared) {
+    int lane = threadIdx.x % WARP_SIZE;
+    int wid = threadIdx.x / WARP_SIZE;
+
+    val = warpReduceSum(val);
+
+    if (lane == 0)
+        shared[wid] = val;
+
+    __syncthreads();
+
+    Real sum = 0.0;
+    if (wid == 0) {
+        sum = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0;
+        sum = warpReduceSum(sum);
+        if (lane == 0) shared[0] = sum;
+    }
+
+    __syncthreads();
+    return shared[0];
+}
+
+
 template <typename Real>
 __global__ void normalize_kernel(
     thrust::complex<Real>* psi_iter,
@@ -292,38 +325,19 @@ __global__ void normalize_kernel(
 {
     int m = blockIdx.x;
     int tid = threadIdx.x;
-    __shared__ Real sum[thread_per_block];
+    extern __shared__ char s_char[];
+    Real* shared = reinterpret_cast<Real*>(s_char);
 
-    sum[tid] = 0.0;
+    Real local_sum = 0.0;
 
     // Calculate the sum for normalization
    for (int i = tid; i < dim; i += thread_per_block) {
         auto val = psi_iter[(nbase + m) * dim + i];
-        sum[tid] += (val * thrust::conj(val)).real();
+        local_sum += (val * thrust::conj(val)).real();
     }
 
-    __syncthreads();
-
-    // Parallel reduction in shared memory
-    for (int s = thread_per_block/2; s > warp_size; s >>= 1) {
-        if (tid < s) {
-            sum[tid] += sum[tid + s];
-        }
-        __syncthreads();
-    }
-
-    if (tid < warp_size) {
-        sum[tid] += sum[tid + 32]; __syncwarp();
-        sum[tid] += sum[tid + 16]; __syncwarp();
-        sum[tid] += sum[tid + 8]; __syncwarp();
-        sum[tid] += sum[tid + 4]; __syncwarp();
-        sum[tid] += sum[tid + 2]; __syncwarp();
-        sum[tid] += sum[tid + 1]; __syncwarp();
-    }
-
-    __syncthreads();
-
-    Real norm = sqrt(sum[0]);
+    Real l2_sq = blockReduceSum(local_sum, shared);
+    Real norm = sqrt(l2_sq);
 
     // Normalize the vector
     for (int i = tid; i < dim; i += thread_per_block) {
@@ -452,8 +466,9 @@ void normalize_op<T, base_device::DEVICE_GPU>::operator()(const int& dim,
                                                           Real* psi_norm)
 {
     auto psi_complex = reinterpret_cast<thrust::complex<Real>*>(psi_iter);
+    int sharedMemSize = (thread_per_block / WARP_SIZE) * sizeof(Real);
 
-    normalize_kernel<Real><<<notconv, thread_per_block>>>(
+    normalize_kernel<Real><<<notconv, thread_per_block, sharedMemSize, 0>>>(
         psi_complex, psi_norm, dim, nbase, notconv);
 
     cudaCheckOnDebug();
```
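
One hypothetical way to sanity-check the new reduction path (not part of the commit; the harness below, including block_sum_kernel, is made up for illustration, with the commit's templates specialized to float): allocate one shared-memory slot per warp, exactly as the updated launch does, and compare the block-wide sum against a known value.

```cuda
// Hypothetical test harness (not part of the commit): exercises a block-wide
// shuffle+shared reduction with the same dynamic shared memory sizing as the
// updated normalize_kernel launch.
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff
#define WARP_SIZE 32
const int thread_per_block = 256;

__device__ float warpReduceSum(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1)
        val += __shfl_down_sync(FULL_MASK, val, offset);
    return val;
}

// Same shape as the commit's blockReduceSum: per-warp shuffle reduction,
// per-warp partial sums staged in shared memory, final reduction by warp 0.
__device__ float blockReduceSum(float val, volatile float* shared) {
    int lane = threadIdx.x % WARP_SIZE;
    int wid = threadIdx.x / WARP_SIZE;
    val = warpReduceSum(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();
    float sum = 0.0f;
    if (wid == 0) {
        sum = (threadIdx.x < blockDim.x / WARP_SIZE) ? shared[lane] : 0.0f;
        sum = warpReduceSum(sum);
        if (lane == 0) shared[0] = sum;
    }
    __syncthreads();
    return shared[0];
}

__global__ void block_sum_kernel(const float* in, float* out, int n) {
    extern __shared__ char s_char[];
    float* shared = reinterpret_cast<float*>(s_char);
    // Grid-stride-style accumulation within one block, as in normalize_kernel.
    float local = 0.0f;
    for (int i = threadIdx.x; i < n; i += thread_per_block)
        local += in[i];
    float total = blockReduceSum(local, shared);
    if (threadIdx.x == 0) *out = total;
}

int main() {
    const int n = 10000;
    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    float* h_in = new float[n];
    for (int i = 0; i < n; ++i) h_in[i] = 1.0f;   // expected sum is n
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    // Dynamic shared memory: one slot per warp, mirroring the commit's launch.
    int sharedMemSize = (thread_per_block / WARP_SIZE) * sizeof(float);
    block_sum_kernel<<<1, thread_per_block, sharedMemSize, 0>>>(d_in, d_out, n);

    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block sum = %g (expect %d)\n", h_out, n);
    delete[] h_in;
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```

Note the shared-memory footprint: the old kernel reserved thread_per_block Real slots per block, while the new path needs only thread_per_block / WARP_SIZE (8 slots at 256 threads), passed at launch as dynamic shared memory.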
