@@ -84,9 +84,9 @@ __global__ void format_nlist_fill_a_se_a(const FPTYPE * coord,
8484 const float rcut,
8585 int_64 * key,
8686 int * i_idx,
87- const int MAGIC_NUMBER )
87+ const int MAX_NBOR_SIZE )
8888{
89- // <<<nloc, MAGIC_NUMBER >>>
89+ // <<<nloc, MAX_NBOR_SIZE >>>
9090 const unsigned int idx = blockIdx .x ;
9191 const unsigned int idy = blockIdx .y * blockDim .y + threadIdx .y ;
9292
@@ -98,7 +98,7 @@ __global__ void format_nlist_fill_a_se_a(const FPTYPE * coord,
9898 const int * nei_idx = jlist + jrange[i_idx[idx]];
9999 // dev_copy(nei_idx, &jlist[jrange[i_idx]], nsize);
100100
101- int_64 * key_in = key + idx * MAGIC_NUMBER ;
101+ int_64 * key_in = key + idx * MAX_NBOR_SIZE ;
102102
103103 FPTYPE diff[3 ];
104104 const int & j_idx = nei_idx[idy];
@@ -121,7 +121,7 @@ __global__ void format_nlist_fill_b_se_a(int * nlist,
121121 const int * sec_a,
122122 const int sec_a_size,
123123 int * nei_iter_dev,
124- const int MAGIC_NUMBER )
124+ const int MAX_NBOR_SIZE )
125125{
126126
127127 const unsigned int idy = blockIdx .x * blockDim .x + threadIdx .x ;
@@ -132,13 +132,13 @@ __global__ void format_nlist_fill_b_se_a(int * nlist,
132132
133133 int * row_nlist = nlist + idy * nlist_size;
134134 int * nei_iter = nei_iter_dev + idy * sec_a_size;
135- int_64 * key_out = key + nloc * MAGIC_NUMBER + idy * MAGIC_NUMBER ;
135+ int_64 * key_out = key + nloc * MAX_NBOR_SIZE + idy * MAX_NBOR_SIZE ;
136136
137137 for (int ii = 0 ; ii < sec_a_size; ii++) {
138138 nei_iter[ii] = sec_a[ii];
139139 }
140140
141- for (unsigned int kk = 0 ; key_out[kk] != key_out[MAGIC_NUMBER - 1 ]; kk++) {
141+ for (unsigned int kk = 0 ; key_out[kk] != key_out[MAX_NBOR_SIZE - 1 ]; kk++) {
142142 const int & nei_type = key_out[kk] / 1E15 ;
143143 if (nei_iter[nei_type] < sec_a[nei_type + 1 ]) {
144144 row_nlist[nei_iter[nei_type]++] = key_out[kk] % 100000 ;
@@ -228,73 +228,6 @@ __global__ void compute_descriptor_se_a (FPTYPE* descript,
228228 }
229229}
230230
231- template <typename FPTYPE>
232- void format_nbor_list_256 (
233- const FPTYPE* coord,
234- const int * type,
235- const int * jrange,
236- const int * jlist,
237- const int & nloc,
238- const float & rcut_r,
239- int * i_idx,
240- int_64 * key
241- )
242- {
243- const int LEN = 256 ;
244- const int MAGIC_NUMBER = 256 ;
245- const int nblock = (MAGIC_NUMBER + LEN - 1 ) / LEN;
246- dim3 block_grid (nloc, nblock);
247- dim3 thread_grid (1 , LEN);
248- format_nlist_fill_a_se_a
249- <<<block_grid, thread_grid>>> (
250- coord,
251- type,
252- jrange,
253- jlist,
254- rcut_r,
255- key,
256- i_idx,
257- MAGIC_NUMBER
258- );
259- const int ITEMS_PER_THREAD = 4 ;
260- const int BLOCK_THREADS = MAGIC_NUMBER / ITEMS_PER_THREAD;
261- // BlockSortKernel<NeighborInfo, BLOCK_THREADS, ITEMS_PER_THREAD><<<g_grid_size, BLOCK_THREADS>>> (
262- BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAGIC_NUMBER);
263- }
264-
265- template <typename FPTYPE>
266- void format_nbor_list_512 (
267- const FPTYPE* coord,
268- const int * type,
269- const int * jrange,
270- const int * jlist,
271- const int & nloc,
272- const float & rcut_r,
273- int * i_idx,
274- int_64 * key
275- )
276- {
277- const int LEN = 256 ;
278- const int MAGIC_NUMBER = 512 ;
279- const int nblock = (MAGIC_NUMBER + LEN - 1 ) / LEN;
280- dim3 block_grid (nloc, nblock);
281- dim3 thread_grid (1 , LEN);
282- format_nlist_fill_a_se_a
283- <<<block_grid, thread_grid>>> (
284- coord,
285- type,
286- jrange,
287- jlist,
288- rcut_r,
289- key,
290- i_idx,
291- MAGIC_NUMBER
292- );
293- const int ITEMS_PER_THREAD = 4 ;
294- const int BLOCK_THREADS = MAGIC_NUMBER / ITEMS_PER_THREAD;
295- // BlockSortKernel<NeighborInfo, BLOCK_THREADS, ITEMS_PER_THREAD><<<g_grid_size, BLOCK_THREADS>>> (
296- BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAGIC_NUMBER);
297- }
298231
299232template <typename FPTYPE>
300233void format_nbor_list_1024 (
@@ -309,8 +242,8 @@ void format_nbor_list_1024 (
309242)
310243{
311244 const int LEN = 256 ;
312- const int MAGIC_NUMBER = 1024 ;
313- const int nblock = (MAGIC_NUMBER + LEN - 1 ) / LEN;
245+ const int MAX_NBOR_SIZE = 1024 ;
246+ const int nblock = (MAX_NBOR_SIZE + LEN - 1 ) / LEN;
314247 dim3 block_grid (nloc, nblock);
315248 dim3 thread_grid (1 , LEN);
316249 format_nlist_fill_a_se_a
@@ -322,12 +255,12 @@ void format_nbor_list_1024 (
322255 rcut_r,
323256 key,
324257 i_idx,
325- MAGIC_NUMBER
258+ MAX_NBOR_SIZE
326259 );
327260 const int ITEMS_PER_THREAD = 8 ;
328- const int BLOCK_THREADS = MAGIC_NUMBER / ITEMS_PER_THREAD;
261+ const int BLOCK_THREADS = MAX_NBOR_SIZE / ITEMS_PER_THREAD;
329262 // BlockSortKernel<NeighborInfo, BLOCK_THREADS, ITEMS_PER_THREAD><<<g_grid_size, BLOCK_THREADS>>> (
330- BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAGIC_NUMBER );
263+ BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAX_NBOR_SIZE );
331264}
332265
333266template <typename FPTYPE>
@@ -343,8 +276,8 @@ void format_nbor_list_2048 (
343276)
344277{
345278 const int LEN = 256 ;
346- const int MAGIC_NUMBER = 2048 ;
347- const int nblock = (MAGIC_NUMBER + LEN - 1 ) / LEN;
279+ const int MAX_NBOR_SIZE = 2048 ;
280+ const int nblock = (MAX_NBOR_SIZE + LEN - 1 ) / LEN;
348281 dim3 block_grid (nloc, nblock);
349282 dim3 thread_grid (1 , LEN);
350283 format_nlist_fill_a_se_a
@@ -356,12 +289,12 @@ void format_nbor_list_2048 (
356289 rcut_r,
357290 key,
358291 i_idx,
359- MAGIC_NUMBER
292+ MAX_NBOR_SIZE
360293 );
361294 const int ITEMS_PER_THREAD = 8 ;
362- const int BLOCK_THREADS = MAGIC_NUMBER / ITEMS_PER_THREAD;
295+ const int BLOCK_THREADS = MAX_NBOR_SIZE / ITEMS_PER_THREAD;
363296 // BlockSortKernel<NeighborInfo, BLOCK_THREADS, ITEMS_PER_THREAD><<<g_grid_size, BLOCK_THREADS>>> (
364- BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAGIC_NUMBER );
297+ BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAX_NBOR_SIZE );
365298}
366299
367300template <typename FPTYPE>
@@ -377,8 +310,8 @@ void format_nbor_list_4096 (
377310)
378311{
379312 const int LEN = 256 ;
380- const int MAGIC_NUMBER = 4096 ;
381- const int nblock = (MAGIC_NUMBER + LEN - 1 ) / LEN;
313+ const int MAX_NBOR_SIZE = 4096 ;
314+ const int nblock = (MAX_NBOR_SIZE + LEN - 1 ) / LEN;
382315 dim3 block_grid (nloc, nblock);
383316 dim3 thread_grid (1 , LEN);
384317 format_nlist_fill_a_se_a
@@ -390,16 +323,16 @@ void format_nbor_list_4096 (
390323 rcut_r,
391324 key,
392325 i_idx,
393- MAGIC_NUMBER
326+ MAX_NBOR_SIZE
394327 );
395328 const int ITEMS_PER_THREAD = 16 ;
396- const int BLOCK_THREADS = MAGIC_NUMBER / ITEMS_PER_THREAD;
329+ const int BLOCK_THREADS = MAX_NBOR_SIZE / ITEMS_PER_THREAD;
397330 // BlockSortKernel<NeighborInfo, BLOCK_THREADS, ITEMS_PER_THREAD><<<g_grid_size, BLOCK_THREADS>>> (
398- BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAGIC_NUMBER );
331+ BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAX_NBOR_SIZE );
399332}
400333
401334template <typename FPTYPE>
402- void DescrptSeAGPUExecuteFunctor<FPTYPE>::operator ()(const FPTYPE * coord, const int * type, const int * ilist, const int * jrange, const int * jlist, int * array_int, unsigned long long * array_longlong, const FPTYPE * avg, const FPTYPE * std, FPTYPE * descript, FPTYPE * descript_deriv, FPTYPE * rij, int * nlist, const int nloc, const int nall, const int nnei, const int ndescrpt, const float rcut_r, const float rcut_r_smth, const std::vector<int > sec_a, const bool fill_nei_a, const int MAGIC_NUMBER ) {
335+ void DescrptSeAGPUExecuteFunctor<FPTYPE>::operator ()(const FPTYPE * coord, const int * type, const int * ilist, const int * jrange, const int * jlist, int * array_int, unsigned long long * array_longlong, const FPTYPE * avg, const FPTYPE * std, FPTYPE * descript, FPTYPE * descript_deriv, FPTYPE * rij, int * nlist, const int nloc, const int nall, const int nnei, const int ndescrpt, const float rcut_r, const float rcut_r_smth, const std::vector<int > sec_a, const bool fill_nei_a, const int max_nbor_size ) {
403336 const int LEN = 256 ;
404337 int nblock = (nloc + LEN -1 ) / LEN;
405338 int * sec_a_dev = array_int;
@@ -409,7 +342,7 @@ void DescrptSeAGPUExecuteFunctor<FPTYPE>::operator()(const FPTYPE * coord, const
409342
410343 cudaError_t res = cudaSuccess;
411344 res = cudaMemcpy (sec_a_dev, &sec_a[0 ], sizeof (int ) * sec_a.size (), cudaMemcpyHostToDevice); cudaErrcheck (res);
412- res = cudaMemset (key, 0xffffffff , sizeof (int_64) * nloc * MAGIC_NUMBER ); cudaErrcheck (res);
345+ res = cudaMemset (key, 0xffffffff , sizeof (int_64) * nloc * max_nbor_size ); cudaErrcheck (res);
413346 res = cudaMemset (nlist, -1 , sizeof (int ) * nloc * nnei); cudaErrcheck (res);
414347 res = cudaMemset (descript, 0.0 , sizeof (FPTYPE) * nloc * ndescrpt); cudaErrcheck (res);
415348 res = cudaMemset (descript_deriv, 0.0 , sizeof (FPTYPE) * nloc * ndescrpt * 3 ); cudaErrcheck (res);
@@ -419,29 +352,7 @@ void DescrptSeAGPUExecuteFunctor<FPTYPE>::operator()(const FPTYPE * coord, const
419352 // cudaProfilerStart();
420353 get_i_idx_se_a<<<nblock, LEN>>> (nloc, ilist, i_idx);
421354
422- if (nnei <= 256 ) {
423- format_nbor_list_256 (
424- coord,
425- type,
426- jrange,
427- jlist,
428- nloc,
429- rcut_r,
430- i_idx,
431- key
432- );
433- } else if (nnei <= 512 ) {
434- format_nbor_list_512 (
435- coord,
436- type,
437- jrange,
438- jlist,
439- nloc,
440- rcut_r,
441- i_idx,
442- key
443- );
444- } else if (nnei <= 1024 ) {
355+ if (max_nbor_size <= 1024 ) {
445356 format_nbor_list_1024 (
446357 coord,
447358 type,
@@ -452,7 +363,7 @@ void DescrptSeAGPUExecuteFunctor<FPTYPE>::operator()(const FPTYPE * coord, const
452363 i_idx,
453364 key
454365 );
455- } else if (nnei <= 2048 ) {
366+ } else if (max_nbor_size <= 2048 ) {
456367 format_nbor_list_2048 (
457368 coord,
458369 type,
@@ -463,7 +374,7 @@ void DescrptSeAGPUExecuteFunctor<FPTYPE>::operator()(const FPTYPE * coord, const
463374 i_idx,
464375 key
465376 );
466- } else if (nnei <= 4096 ) {
377+ } else if (max_nbor_size <= 4096 ) {
467378 format_nbor_list_4096 (
468379 coord,
469380 type,
@@ -486,7 +397,7 @@ void DescrptSeAGPUExecuteFunctor<FPTYPE>::operator()(const FPTYPE * coord, const
486397 sec_a_dev,
487398 sec_a.size (),
488399 nei_iter,
489- MAGIC_NUMBER
400+ max_nbor_size
490401 );
491402 }
492403
0 commit comments