Skip to content

Commit 984ef69

Browse files
authored
Merge pull request #162 from amcadmus/master
Fix GPU kernel threading bugs
2 parents 0f9edb4 + d1fa044 commit 984ef69

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

source/op/cuda/descrpt_se_a.cu

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,15 @@ __global__ void compute_descriptor_se_a (VALUETYPE* descript,
208208
const VALUETYPE* coord,
209209
const VALUETYPE rmin,
210210
const VALUETYPE rmax,
211-
compute_t* sel_a_diff_dev)
211+
compute_t* sel_a_diff_dev,
212+
const int sec_a_size)
212213
{
213214
// <<<dim3((sec_a.back() + LEN - 1) / LEN, nloc), LEN>>>
214-
const unsigned int idx = blockIdx.x;
215-
const unsigned int idy = threadIdx.x;
215+
const unsigned int idx = blockIdx.y;
216+
const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x;
216217
const int idx_deriv = idy * 4 * 3; // 4 components time 3 directions
217218
const int idx_value = idy * 4; // 4 components
219+
if (idy >= sec_a_size) {return;}
218220

219221
// else {return;}
220222
VALUETYPE * row_descript = descript + idx * ndescrpt;
@@ -355,7 +357,9 @@ void DescrptSeALauncher(const VALUETYPE* coord,
355357
);
356358
}
357359

358-
compute_descriptor_se_a<<<nloc, sec_a.back()>>> (
360+
const int nblock_ = (sec_a.back() + LEN -1) / LEN;
361+
dim3 block_grid(nblock_, nloc);
362+
compute_descriptor_se_a<<<block_grid, LEN>>> (
359363
descript,
360364
ndescrpt,
361365
descript_deriv,
@@ -370,7 +374,8 @@ void DescrptSeALauncher(const VALUETYPE* coord,
370374
coord,
371375
rcut_r_smth,
372376
rcut_r,
373-
sel_a_diff
377+
sel_a_diff,
378+
sec_a.back()
374379
);
375380
////
376381
// res = cudaFree(sec_a_dev); cudaErrcheck(res);

source/op/cuda/descrpt_se_r.cu

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,16 @@ __global__ void compute_descriptor_se_r (VALUETYPE* descript,
209209
const VALUETYPE* coord,
210210
const VALUETYPE rmin,
211211
const VALUETYPE rmax,
212-
compute_t* sel_diff_dev)
212+
compute_t* sel_diff_dev,
213+
const int sec_size)
213214
{
214215
// <<<dim3((sec.back() + LEN - 1) / LEN, nloc), LEN>>>
215-
const unsigned int idx = blockIdx.x;
216-
const unsigned int idy = threadIdx.x;
216+
const unsigned int idx = blockIdx.y;
217+
const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x;
217218
const int idx_deriv = idy * 3; // 1 components time 3 directions
218219
const int idx_value = idy; // 1 components
219-
220+
if (idy >= sec_size) {return;}
221+
220222
// else {return;}
221223
VALUETYPE * row_descript = descript + idx * ndescrpt;
222224
VALUETYPE * row_descript_deriv = descript_deriv + idx * descript_deriv_size;
@@ -324,7 +326,9 @@ void DescrptSeRLauncher(const VALUETYPE* coord,
324326
nei_iter
325327
);
326328
}
327-
compute_descriptor_se_r<<<nloc, sec.back()>>> (
329+
const int nblock_ = (sec.back() + LEN -1) / LEN;
330+
dim3 block_grid(nblock_, nloc);
331+
compute_descriptor_se_r<<<block_grid, LEN>>> (
328332
descript,
329333
ndescrpt,
330334
descript_deriv,
@@ -339,6 +343,7 @@ void DescrptSeRLauncher(const VALUETYPE* coord,
339343
coord,
340344
rcut_smth,
341345
rcut,
342-
sel_diff
346+
sel_diff,
347+
sec.back()
343348
);
344349
}

0 commit comments

Comments
 (0)