Skip to content

Commit bf9ba83

Browse files
LuLu
authored and committed
fix wrong output from the prod_force_se_a GPU kernel
1 parent be6ca64 commit bf9ba83

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

source/op/cuda/prod_force_se_a.cu

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -37,9 +37,13 @@ __global__ void deriv_wrt_center_atom_se_a(VALUETYPE * force,
3737
const VALUETYPE * in_deriv,
3838
const int ndescrpt)
3939
{
40-
const unsigned int idx = blockIdx.x;
41-
const unsigned int idy = threadIdx.x;
42-
const unsigned int idz = blockIdx.y;
40+
const unsigned int idx = blockIdx.y;
41+
const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x;
42+
const unsigned int idz = threadIdx.y;
43+
44+
if (idy >= ndescrpt) {
45+
return;
46+
}
4347

4448
atomicAdd(force + idx * 3 + idz, -1.0 * net_deriv[idx * ndescrpt + idy] * in_deriv[idx * ndescrpt * 3 + idy * 3 + idz]);
4549
}
@@ -84,7 +88,6 @@ void ProdForceSeALauncher(VALUETYPE * force,
8488
{
8589
// std::cout << "I'm here!" << std::endl;
8690
cudaErrcheck(cudaMemset(force, 0.0, sizeof(VALUETYPE) * nall * 3));
87-
dim3 grid(nloc, 3);
8891
const int LEN1 = 256;
8992
const int nblock1 = (ndescrpt + LEN1 -1) / LEN1;
9093
dim3 grid(nblock1, nloc);

source/op/cuda/prod_force_se_r.cu

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -36,10 +36,14 @@ __global__ void deriv_wrt_center_atom_se_r(VALUETYPE * force,
3636
const VALUETYPE * in_deriv,
3737
const int ndescrpt)
3838
{
39-
const unsigned int idx = blockIdx.x;
40-
const unsigned int idy = threadIdx.x;
41-
const unsigned int idz = blockIdx.y;
39+
const unsigned int idx = blockIdx.y;
40+
const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x;
41+
const unsigned int idz = threadIdx.y;
4242

43+
if (idy >= ndescrpt) {
44+
return;
45+
}
46+
4347
atomicAdd(force + idx * 3 + idz, -1.0 * net_deriv[idx * ndescrpt + idy] * in_deriv[idx * ndescrpt * 3 + idy * 3 + idz]);
4448
}
4549

@@ -81,8 +85,11 @@ void ProdForceSeRLauncher(VALUETYPE * force,
8185
const int n_a_shift)
8286
{
8387
cudaErrcheck(cudaMemset(force, 0.0, sizeof(VALUETYPE) * nall * 3));
84-
dim3 grid(nloc, 3);
85-
deriv_wrt_center_atom_se_r<<<grid, ndescrpt>>>(force, net_deriv, in_deriv, ndescrpt);
88+
const int LEN1 = 256;
89+
const int nblock1 = (ndescrpt + LEN1 -1) / LEN1;
90+
dim3 grid(nblock1, nloc);
91+
dim3 thread(LEN1, 3);
92+
deriv_wrt_center_atom_se_r<<<grid, thread>>>(force, net_deriv, in_deriv, ndescrpt);
8693

8794
const int LEN = 64;
8895
int nblock = (nloc + LEN -1) / LEN;

0 commit comments

Comments
 (0)