|
2 | 2 | #include "device.h" |
3 | 3 | #include "neighbor_list.h" |
4 | 4 |
|
| 5 | +#include "hipcub/hipcub.hpp" |
| 6 | +// A stateful callback functor that maintains a running prefix to be applied |
| 7 | +// during consecutive scan operations. |
| 8 | +struct parallel_prefix_scan_op |
| 9 | +{ |
| 10 | + // Running prefix |
| 11 | + int running_total; |
| 12 | + // Constructor |
| 13 | + __device__ parallel_prefix_scan_op(int running_total) : running_total(running_total) {} |
| 14 | + // Callback operator to be entered by the first warp of threads in the block. |
| 15 | + // Thread-0 is responsible for returning a value for seeding the block-wide scan. |
| 16 | + __device__ int operator()(int block_aggregate) |
| 17 | + { |
| 18 | + int old_prefix = running_total; |
| 19 | + running_total += block_aggregate; |
| 20 | + return old_prefix; |
| 21 | + } |
| 22 | +}; |
| 23 | + |
| 24 | +template < |
| 25 | + int THREADS_PER_BLOCK> |
| 26 | +__global__ void parallel_prefix_scan( |
| 27 | + int * numneigh, |
| 28 | + int * nei_order, |
| 29 | + const int * temp_nlist, |
| 30 | + const int mem_size, |
| 31 | + const int nloc, |
| 32 | + const int nall |
| 33 | +) |
| 34 | +{ |
| 35 | + // Specialize BlockLoad, BlockStore, and BlockScan for a 1D block of 128 threads, 4 ints per thread |
| 36 | + typedef hipcub::BlockScan<int, THREADS_PER_BLOCK> BlockScan; |
| 37 | + // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan |
| 38 | + __shared__ typename BlockScan::TempStorage temp_storage; |
| 39 | + |
| 40 | + // Initialize running total |
| 41 | + parallel_prefix_scan_op prefix_op(0); |
| 42 | + |
| 43 | + // Have the block iterate over segments of items |
| 44 | + for (int ii = threadIdx.x; ii < nall; ii += THREADS_PER_BLOCK) |
| 45 | + { |
| 46 | + int block_offset = blockIdx.x * mem_size; |
| 47 | + // Load a segment of consecutive items that are blocked across threads |
| 48 | + int i_data = temp_nlist[block_offset + ii]; |
| 49 | + int o_data = i_data == -1 ? 0 : 1; |
| 50 | + |
| 51 | + // Collectively compute the block-wide exclusive prefix sum |
| 52 | + BlockScan(temp_storage).ExclusiveSum( |
| 53 | + o_data, o_data, prefix_op); |
| 54 | + |
| 55 | + __syncthreads(); |
| 56 | + // Store scanned items to output segment |
| 57 | + if (i_data != -1) { |
| 58 | + nei_order[block_offset + ii] = o_data; |
| 59 | + } |
| 60 | + // Store numneigh into the output array |
| 61 | + if (ii == nall - 1) { |
| 62 | + o_data += i_data == -1 ? 0 : 1; |
| 63 | + numneigh[blockIdx.x] = o_data; |
| 64 | + } |
| 65 | + } |
| 66 | +} |
| 67 | + |
5 | 68 | template<typename FPTYPE> |
6 | 69 | __device__ inline FPTYPE dev_dot( |
7 | 70 | FPTYPE * arr1, |
@@ -45,29 +108,6 @@ __global__ void build_nlist( |
45 | 108 | } |
46 | 109 | } |
47 | 110 |
|
// Serial per-row scan (removed by this change in favor of the block-parallel
// hipcub version above): one thread per local atom compacts its row of the
// -1-padded matrix temp_nlist into contiguous order indices and counts the
// row's valid entries.
__global__ void scan_nlist(
  int * numneigh,          // out: valid-entry count per local atom, length nloc
  int * nei_order,         // out: per-row compaction index for each valid entry
  const int * temp_nlist,  // in : rows of stride mem_size; -1 marks empty slots
  const int mem_size,      // row stride of temp_nlist / nei_order
  const int nloc,          // number of rows (local atoms) processed
  const int nall)          // number of columns scanned in each row
{
  // One thread handles one whole row; excess threads fall through the guard.
  const unsigned int atom_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(atom_idx<nloc){
    const int * row_nlist = temp_nlist + atom_idx * mem_size;
    int * row_order = nei_order + atom_idx * mem_size;
    int nei_num=0;
    // Sequentially assign 0..k-1 to the valid (!= -1) entries of this row.
    for(int i=0;i<nall;i++){
      if(row_nlist[i]!=-1){
        row_order[i]=nei_num;
        nei_num++;
      }
    }
    numneigh[atom_idx]=nei_num;
  }
}
71 | 111 | __global__ void fill_nlist( |
72 | 112 | int ** firstneigh, |
73 | 113 | const int * temp_nlist, |
@@ -143,14 +183,10 @@ int build_nlist_gpu_rocm( |
143 | 183 | mem_size); |
144 | 184 | DPErrcheck(hipGetLastError()); |
145 | 185 | DPErrcheck(hipDeviceSynchronize()); |
146 | | - const int nblock_ = (nloc+TPB-1)/TPB; |
147 | | - hipLaunchKernelGGL(scan_nlist, nblock_, TPB, 0, 0, |
148 | | - numneigh, |
149 | | - nei_order, |
150 | | - temp_nlist, |
151 | | - mem_size, |
152 | | - nloc, |
153 | | - nall); |
| 186 | + hipLaunchKernelGGL( |
| 187 | + HIP_KERNEL_NAME(parallel_prefix_scan<TPB>), nloc, TPB, 0, 0, |
| 188 | + numneigh, nei_order, |
| 189 | + temp_nlist, mem_size, nloc, nall); |
154 | 190 | DPErrcheck(hipGetLastError()); |
155 | 191 | DPErrcheck(hipDeviceSynchronize()); |
156 | 192 | hipLaunchKernelGGL(fill_nlist, block_grid, thread_grid, 0, 0, |
|
0 commit comments