Skip to content

Commit 69b9093

Browse files
authored
Merge pull request #161 from denghuilu/devel-submit
fix bug of zero output when sec_a.back() is larger than 1024
2 parents 2470705 + fb79cbe commit 69b9093

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

source/op/cuda/descrpt_se_a.cu

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,15 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
194194
const VALUETYPE* coord,
195195
const VALUETYPE rmin,
196196
const VALUETYPE rmax,
197-
compute_t* sel_a_diff_dev)
197+
compute_t* sel_a_diff_dev,
198+
const int sec_a_size)
198199
{
199200
// <<<nloc, sec_a.back()>>>
200-
const unsigned int idx = blockIdx.x;
201-
const unsigned int idy = threadIdx.x;
201+
const unsigned int idx = blockIdx.y;
202+
const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x;
202203
const int idx_deriv = idy * 4 * 3; // 4 components time 3 directions
203204
const int idx_value = idy * 4; // 4 components
205+
if (idy >= sec_a_size) {return;}
204206

205207
// else {return;}
206208
VALUETYPE * row_descript = descript + idx * ndescrpt;
@@ -341,7 +343,9 @@ void DescrptSeALauncher(const VALUETYPE* coord,
341343
);
342344
}
343345

344-
compute_descriptor_se_a<<<nloc, sec_a.back()>>> (
346+
const int nblock_ = (sec_a.back() + LEN -1) / LEN;
347+
dim3 block_grid(nblock_, nloc);
348+
compute_descriptor_se_a<<<block_grid, LEN>>> (
345349
descript,
346350
ndescrpt,
347351
descript_deriv,
@@ -356,7 +360,8 @@ void DescrptSeALauncher(const VALUETYPE* coord,
356360
coord,
357361
rcut_r_smth,
358362
rcut_r,
359-
sel_a_diff
363+
sel_a_diff,
364+
sec_a.back()
360365
);
361366
////
362367
// res = cudaFree(sec_a_dev); cudaErrcheck(res);

source/op/cuda/descrpt_se_r.cu

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,16 @@ __global__ void compute_descriptor_se_r (VALUETYPE* descript,
195195
const VALUETYPE* coord,
196196
const VALUETYPE rmin,
197197
const VALUETYPE rmax,
198-
compute_t* sel_diff_dev)
198+
compute_t* sel_diff_dev,
199+
const int sec_size)
199200
{
200201
// <<<nloc, sec.back()>>>
201-
const unsigned int idx = blockIdx.x;
202-
const unsigned int idy = threadIdx.x;
202+
const unsigned int idx = blockIdx.y;
203+
const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x;
203204
const int idx_deriv = idy * 3; // 1 components time 3 directions
204205
const int idx_value = idy; // 1 components
205-
206+
if (idy >= sec_size) {return;}
207+
206208
// else {return;}
207209
VALUETYPE * row_descript = descript + idx * ndescrpt;
208210
VALUETYPE * row_descript_deriv = descript_deriv + idx * descript_deriv_size;
@@ -310,7 +312,9 @@ void DescrptSeRLauncher(const VALUETYPE* coord,
310312
nei_iter
311313
);
312314
}
313-
compute_descriptor_se_r<<<nloc, sec.back()>>> (
315+
const int nblock_ = (sec.back() + LEN -1) / LEN;
316+
dim3 block_grid(nblock_, nloc);
317+
compute_descriptor_se_r<<<block_grid, LEN>>> (
314318
descript,
315319
ndescrpt,
316320
descript_deriv,
@@ -325,6 +329,7 @@ void DescrptSeRLauncher(const VALUETYPE* coord,
325329
coord,
326330
rcut_smth,
327331
rcut,
328-
sel_diff
332+
sel_diff,
333+
sec.back()
329334
);
330335
}

0 commit comments

Comments
 (0)