Skip to content

Commit ab355d0

Browse files
authored
Merge pull request #139 from denghuilu/devel-up
fix prod_force GPU kernels error of wrong output
2 parents 324c527 + bf9ba83 commit ab355d0

File tree

4 files changed

+30
-11
lines changed

4 files changed

+30
-11
lines changed

source/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ include_directories(${TensorFlow_INCLUDE_DIRS})
170170
if (BUILD_CPP_IF)
171171
set (LIB_DEEPMD "deepmd")
172172
set (LIB_DEEPMD_OP "deepmd_op")
173+
if (USE_CUDA_TOOLKIT)
174+
set (LIB_DEEPMD_OP_CUDA "deepmd_op_cuda")
175+
else()
176+
set (LIB_DEEPMD_OP_CUDA "")
177+
endif()
173178
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 4.9)
174179
set (LIB_DEEPMD_NATIVE "deepmd_native_md")
175180
set (LIB_DEEPMD_IPI "deepmd_ipi")

source/lmp/env.sh.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ TF_RPATH=`echo $TENSORFLOW_LIBRARY_PATH | sed "s/;/ -Wl,-rpath=/g"`
88

99
NNP_INC=" -std=c++11 @PREC_DEF@ @TTM_DEF@ -I$TF_INCLUDE_DIRS -I$DEEPMD_ROOT/include/deepmd "
1010
NNP_PATH=" -L$TF_LIBRARY_PATH -L$DEEPMD_ROOT/lib"
11-
NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD@ -ldeepmd_op_cuda -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib"
11+
NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD_OP_CUDA@ -l@LIB_DEEPMD@ -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib"

source/op/cuda/prod_force_se_a.cu

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,13 @@ __global__ void deriv_wrt_center_atom_se_a(VALUETYPE * force,
3737
const VALUETYPE * in_deriv,
3838
const int ndescrpt)
3939
{
40-
const unsigned int idx = blockIdx.x;
41-
const unsigned int idy = threadIdx.x;
42-
const unsigned int idz = blockIdx.y;
40+
const unsigned int idx = blockIdx.y;
41+
const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x;
42+
const unsigned int idz = threadIdx.y;
43+
44+
if (idy >= ndescrpt) {
45+
return;
46+
}
4347

4448
atomicAdd(force + idx * 3 + idz, -1.0 * net_deriv[idx * ndescrpt + idy] * in_deriv[idx * ndescrpt * 3 + idy * 3 + idz]);
4549
}
@@ -84,8 +88,11 @@ void ProdForceSeALauncher(VALUETYPE * force,
8488
{
8589
// std::cout << "I'm here!" << std::endl;
8690
cudaErrcheck(cudaMemset(force, 0.0, sizeof(VALUETYPE) * nall * 3));
87-
dim3 grid(nloc, 3);
88-
deriv_wrt_center_atom_se_a<<<grid, ndescrpt>>>(force, net_deriv, in_deriv, ndescrpt);
91+
const int LEN1 = 256;
92+
const int nblock1 = (ndescrpt + LEN1 -1) / LEN1;
93+
dim3 grid(nblock1, nloc);
94+
dim3 thread(LEN1, 3);
95+
deriv_wrt_center_atom_se_a<<<grid, thread>>>(force, net_deriv, in_deriv, ndescrpt);
8996

9097
const int LEN = 64;
9198
int nblock = (nloc + LEN -1) / LEN;

source/op/cuda/prod_force_se_r.cu

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,14 @@ __global__ void deriv_wrt_center_atom_se_r(VALUETYPE * force,
3636
const VALUETYPE * in_deriv,
3737
const int ndescrpt)
3838
{
39-
const unsigned int idx = blockIdx.x;
40-
const unsigned int idy = threadIdx.x;
41-
const unsigned int idz = blockIdx.y;
39+
const unsigned int idx = blockIdx.y;
40+
const unsigned int idy = blockIdx.x * blockDim.x + threadIdx.x;
41+
const unsigned int idz = threadIdx.y;
4242

43+
if (idy >= ndescrpt) {
44+
return;
45+
}
46+
4347
atomicAdd(force + idx * 3 + idz, -1.0 * net_deriv[idx * ndescrpt + idy] * in_deriv[idx * ndescrpt * 3 + idy * 3 + idz]);
4448
}
4549

@@ -81,8 +85,11 @@ void ProdForceSeRLauncher(VALUETYPE * force,
8185
const int n_a_shift)
8286
{
8387
cudaErrcheck(cudaMemset(force, 0.0, sizeof(VALUETYPE) * nall * 3));
84-
dim3 grid(nloc, 3);
85-
deriv_wrt_center_atom_se_r<<<grid, ndescrpt>>>(force, net_deriv, in_deriv, ndescrpt);
88+
const int LEN1 = 256;
89+
const int nblock1 = (ndescrpt + LEN1 -1) / LEN1;
90+
dim3 grid(nblock1, nloc);
91+
dim3 thread(LEN1, 3);
92+
deriv_wrt_center_atom_se_r<<<grid, thread>>>(force, net_deriv, in_deriv, ndescrpt);
8693

8794
const int LEN = 64;
8895
int nblock = (nloc + LEN -1) / LEN;

0 commit comments

Comments
 (0)