Skip to content

Commit be6ca64

Browse files
LuLu
authored and committed
fix prod_force_se_a GPU kernel error of wrong output
set ndescrpt in grid level
1 parent 2891378 commit be6ca64

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

source/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ include_directories(${TensorFlow_INCLUDE_DIRS})
170170
if (BUILD_CPP_IF)
171171
set (LIB_DEEPMD "deepmd")
172172
set (LIB_DEEPMD_OP "deepmd_op")
173+
if (USE_CUDA_TOOLKIT)
174+
set (LIB_DEEPMD_OP_CUDA "deepmd_op_cuda")
175+
else()
176+
set (LIB_DEEPMD_OP_CUDA "")
177+
endif()
173178
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 4.9)
174179
set (LIB_DEEPMD_NATIVE "deepmd_native_md")
175180
set (LIB_DEEPMD_IPI "deepmd_ipi")

source/lmp/env.sh.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ TF_RPATH=`echo $TENSORFLOW_LIBRARY_PATH | sed "s/;/ -Wl,-rpath=/g"`
88

99
NNP_INC=" -std=c++11 @PREC_DEF@ @TTM_DEF@ -I$TF_INCLUDE_DIRS -I$DEEPMD_ROOT/include/deepmd "
1010
NNP_PATH=" -L$TF_LIBRARY_PATH -L$DEEPMD_ROOT/lib"
11-
NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD@ -ldeepmd_op_cuda -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib"
11+
NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD_OP_CUDA@ -l@LIB_DEEPMD@ -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib"

source/op/cuda/prod_force_se_a.cu

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ void ProdForceSeALauncher(VALUETYPE * force,
8585
// std::cout << "I'm here!" << std::endl;
8686
cudaErrcheck(cudaMemset(force, 0.0, sizeof(VALUETYPE) * nall * 3));
8787
dim3 grid(nloc, 3);
88-
deriv_wrt_center_atom_se_a<<<grid, ndescrpt>>>(force, net_deriv, in_deriv, ndescrpt);
88+
const int LEN1 = 256;
89+
const int nblock1 = (ndescrpt + LEN1 -1) / LEN1;
90+
dim3 grid(nblock1, nloc);
91+
dim3 thread(LEN1, 3);
92+
deriv_wrt_center_atom_se_a<<<grid, thread>>>(force, net_deriv, in_deriv, ndescrpt);
8993

9094
const int LEN = 64;
9195
int nblock = (nloc + LEN -1) / LEN;

0 commit comments

Comments
 (0)