File tree Expand file tree Collapse file tree 2 files changed +21
-11
lines changed Expand file tree Collapse file tree 2 files changed +21
-11
lines changed Original file line number Diff line number Diff line change @@ -208,13 +208,15 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
208208 const VALUETYPE* coord,
209209 const VALUETYPE rmin,
210210 const VALUETYPE rmax,
211- compute_t * sel_a_diff_dev)
211+ compute_t * sel_a_diff_dev,
212+ const int sec_a_size)
212213{
213214 // launched as <<<dim3(nblock_, nloc), LEN>>> with nblock_ = (sec_a.back() + LEN - 1) / LEN
214- const unsigned int idx = blockIdx .x ;
215- const unsigned int idy = threadIdx .x ;
215+ const unsigned int idx = blockIdx .y ;
216+ const unsigned int idy = blockIdx . x * blockDim . x + threadIdx .x ;
216217 const int idx_deriv = idy * 4 * 3 ; // 4 components times 3 directions
217218 const int idx_value = idy * 4 ; // 4 components
219+ if (idy >= sec_a_size) {return ;}
218220
219221 // else {return;}
220222 VALUETYPE * row_descript = descript + idx * ndescrpt;
@@ -355,7 +357,9 @@ void DescrptSeALauncher(const VALUETYPE* coord,
355357 );
356358 }
357359
358- compute_descriptor_se_a<<<nloc, sec_a.back()>>> (
360+ const int nblock_ = (sec_a.back () + LEN -1 ) / LEN;
361+ dim3 block_grid (nblock_, nloc);
362+ compute_descriptor_se_a<<<block_grid, LEN>>> (
359363 descript,
360364 ndescrpt,
361365 descript_deriv,
@@ -370,7 +374,8 @@ void DescrptSeALauncher(const VALUETYPE* coord,
370374 coord,
371375 rcut_r_smth,
372376 rcut_r,
373- sel_a_diff
377+ sel_a_diff,
378+ sec_a.back ()
374379 );
375380// //
376381 // res = cudaFree(sec_a_dev); cudaErrcheck(res);
Original file line number Diff line number Diff line change @@ -209,14 +209,16 @@ __global__ void compute_descriptor_se_r (VALUETYPE* descript,
209209 const VALUETYPE* coord,
210210 const VALUETYPE rmin,
211211 const VALUETYPE rmax,
212- compute_t * sel_diff_dev)
212+ compute_t * sel_diff_dev,
213+ const int sec_size)
213214{
214215 // launched as <<<dim3(nblock_, nloc), LEN>>> with nblock_ = (sec.back() + LEN - 1) / LEN
215- const unsigned int idx = blockIdx .x ;
216- const unsigned int idy = threadIdx .x ;
216+ const unsigned int idx = blockIdx .y ;
217+ const unsigned int idy = blockIdx . x * blockDim . x + threadIdx .x ;
217218 const int idx_deriv = idy * 3 ; // 1 component times 3 directions
218219 const int idx_value = idy; // 1 component
219-
220+ if (idy >= sec_size) {return ;}
221+
220222 // else {return;}
221223 VALUETYPE * row_descript = descript + idx * ndescrpt;
222224 VALUETYPE * row_descript_deriv = descript_deriv + idx * descript_deriv_size;
@@ -324,7 +326,9 @@ void DescrptSeRLauncher(const VALUETYPE* coord,
324326 nei_iter
325327 );
326328 }
327- compute_descriptor_se_r<<<nloc, sec.back()>>> (
329+ const int nblock_ = (sec.back () + LEN -1 ) / LEN;
330+ dim3 block_grid (nblock_, nloc);
331+ compute_descriptor_se_r<<<block_grid, LEN>>> (
328332 descript,
329333 ndescrpt,
330334 descript_deriv,
@@ -339,6 +343,7 @@ void DescrptSeRLauncher(const VALUETYPE* coord,
339343 coord,
340344 rcut_smth,
341345 rcut,
342- sel_diff
346+ sel_diff,
347+ sec.back ()
343348 );
344349}
You can’t perform that action at this time.
0 commit comments