Skip to content

Commit 5bf6f16

Browse files
authored
Merge pull request #17 from deepmodeling/devel
devel update
2 parents 28dc891 + 60634f9 commit 5bf6f16

File tree

15 files changed

+500
-33
lines changed

15 files changed

+500
-33
lines changed

source/CMakeLists.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,6 @@ include_directories(${TensorFlow_INCLUDE_DIRS})
184184
if (BUILD_CPP_IF)
185185
set (LIB_DEEPMD "deepmd")
186186
set (LIB_DEEPMD_OP "deepmd_op")
187-
if (USE_CUDA_TOOLKIT)
188-
set (LIB_DEEPMD_OP_CUDA "deepmd_op_cuda")
189-
else()
190-
set (LIB_DEEPMD_OP_CUDA "deepmd_op")
191-
endif()
192187
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 4.9)
193188
set (LIB_DEEPMD_NATIVE "deepmd_native_md")
194189
set (LIB_DEEPMD_IPI "deepmd_ipi")

source/cmake/Findtensorflow.cmake

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
# TensorFlowFramework_LIBRARY
1111
# TensorFlowFramework_LIBRARY_PATH
1212

13+
string(REPLACE "lib64" "lib" TENSORFLOW_ROOT_NO64 ${TENSORFLOW_ROOT})
14+
1315
# define the search path
1416
list(APPEND TensorFlow_search_PATHS ${TENSORFLOW_ROOT})
1517
list(APPEND TensorFlow_search_PATHS "${TENSORFLOW_ROOT}/../tensorflow_core")
18+
list(APPEND TensorFlow_search_PATHS ${TENSORFLOW_ROOT_NO64})
19+
list(APPEND TensorFlow_search_PATHS "${TENSORFLOW_ROOT_NO64}/../tensorflow_core")
1620
list(APPEND TensorFlow_search_PATHS "/usr/")
1721
list(APPEND TensorFlow_search_PATHS "/usr/local/")
1822

@@ -28,9 +32,18 @@ find_path(TensorFlow_INCLUDE_DIRS
2832
PATH_SUFFIXES "/include"
2933
NO_DEFAULT_PATH
3034
)
35+
find_path(TensorFlow_INCLUDE_DIRS_GOOGLE
36+
NAMES
37+
google/protobuf/type.pb.h
38+
PATHS ${TensorFlow_search_PATHS}
39+
PATH_SUFFIXES "/include"
40+
NO_DEFAULT_PATH
41+
)
42+
list(APPEND TensorFlow_INCLUDE_DIRS ${TensorFlow_INCLUDE_DIRS_GOOGLE})
43+
3144
if (NOT TensorFlow_INCLUDE_DIRS AND tensorflow_FIND_REQUIRED)
3245
message(FATAL_ERROR
33-
"Not found 'include/tensorflow/core/public/session.h' directory in path '${TensorFlow_search_PATHS}' "
46+
"Not found 'tensorflow/core/public/session.h' directory in path '${TensorFlow_search_PATHS}' "
3447
"You can manually set the tensorflow install path by -DTENSORFLOW_ROOT ")
3548
endif ()
3649

source/lib/src/NNPInter.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "SimulationRegion.h"
44
#include <stdexcept>
55

6-
#define MAGIC_NUMBER 256
6+
#define MAGIC_NUMBER 1024
77

88
#ifdef USE_CUDA_TOOLKIT
99
#include "cuda_runtime.h"

source/lmp/env.sh.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ TF_RPATH=`echo $TENSORFLOW_LIBRARY_PATH | sed "s/;/ -Wl,-rpath=/g"`
88

99
NNP_INC=" -std=c++11 @PREC_DEF@ @TTM_DEF@ @OLD_LMP_PPPM_DEF@ -I$TF_INCLUDE_DIRS -I$DEEPMD_ROOT/include/deepmd "
1010
NNP_PATH=" -L$TF_LIBRARY_PATH -L$DEEPMD_ROOT/lib"
11-
NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD_OP_CUDA@ -l@LIB_DEEPMD@ -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib"
11+
NNP_LIB=" -Wl,--no-as-needed -l@LIB_DEEPMD_OP@ -l@LIB_DEEPMD@ -ltensorflow_cc -ltensorflow_framework -Wl,-rpath=$TF_RPATH -Wl,-rpath=$DEEPMD_ROOT/lib"

source/op/CMakeLists.txt

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
set(OP_LIB ${PROJECT_SOURCE_DIR}/lib/src/SimulationRegion.cpp ${PROJECT_SOURCE_DIR}/lib/src/NeighborList.cpp)
44

55
set (OP_CXX_FLAG -D_GLIBCXX_USE_CXX11_ABI=${OP_CXX_ABI} )
6-
file(GLOB OP_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a.cc descrpt_se_r.cc tab_inter.cc prod_force_se_a.cc prod_virial_se_a.cc prod_force_se_r.cc prod_virial_se_r.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc)
7-
file(GLOB OP_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a_gpu.cc descrpt_se_r_gpu.cc tab_inter.cc prod_force_se_a_gpu.cc prod_virial_se_a_gpu.cc prod_force_se_r_gpu.cc prod_virial_se_r_gpu.cc soft_min.cc soft_min_force.cc soft_min_virial.cc )
6+
file(GLOB OP_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a.cc descrpt_se_r.cc tab_inter.cc prod_force_se_a.cc prod_virial_se_a.cc prod_force_se_r.cc prod_virial_se_r.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc gelu.cc)
7+
file(GLOB OP_PY_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a.cc descrpt_se_r.cc tab_inter.cc prod_force_se_a.cc prod_virial_se_a.cc prod_force_se_r.cc prod_virial_se_r.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc gelu_gpu.cc)
8+
file(GLOB OP_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a_gpu.cc descrpt_se_r_gpu.cc tab_inter.cc prod_force_se_a_gpu.cc prod_virial_se_a_gpu.cc prod_force_se_r_gpu.cc prod_virial_se_r_gpu.cc soft_min.cc soft_min_force.cc soft_min_virial.cc gelu_gpu.cc)
89
file(GLOB OP_GRADS_SRC prod_force_grad.cc prod_force_se_a_grad.cc prod_force_se_r_grad.cc prod_virial_grad.cc prod_virial_se_a_grad.cc prod_virial_se_r_grad.cc soft_min_force_grad.cc soft_min_virial_grad.cc )
910
file(GLOB OP_PY *.py)
1011

@@ -23,8 +24,20 @@ if (BUILD_CPP_IF)
2324
endif (BUILD_CPP_IF)
2425

2526
if (BUILD_PY_IF)
26-
add_library(op_abi SHARED ${OP_SRC} ${OP_LIB})
27-
add_library(op_grads SHARED ${OP_GRADS_SRC})
27+
if (USE_CUDA_TOOLKIT)
28+
add_library(op_abi SHARED ${OP_PY_CUDA_SRC} ${OP_LIB})
29+
add_library(op_grads SHARED ${OP_GRADS_SRC})
30+
add_subdirectory(cuda)
31+
find_package(CUDA REQUIRED)
32+
include_directories(${CUDA_INCLUDE_DIRS})
33+
set (EXTRA_LIBS ${EXTRA_LIBS} deepmd_op_cuda)
34+
target_link_libraries (op_abi ${EXTRA_LIBS})
35+
target_link_libraries (op_grads ${EXTRA_LIBS})
36+
message(STATUS ${TensorFlowFramework_LIBRARY})
37+
else (USE_CUDA_TOOLKIT)
38+
add_library(op_abi SHARED ${OP_SRC} ${OP_LIB})
39+
add_library(op_grads SHARED ${OP_GRADS_SRC})
40+
endif(USE_CUDA_TOOLKIT)
2841
target_link_libraries(
2942
op_abi ${TensorFlowFramework_LIBRARY}
3043
)

source/op/_gelu.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/usr/bin/env python3
2+
"""
3+
First-order derivatives and second-order derivatives for gelu function.
4+
"""
5+
6+
from tensorflow.python.framework import ops
7+
from deepmd.env import op_module
8+
9+
@ops.RegisterGradient("Gelu")
10+
def gelu_cc (op, dy) :
11+
return op_module.gelu_grad(dy, op.inputs[0])
12+
13+
@ops.RegisterGradient("GeluGrad")
14+
def gelu_grad_cc (op, dy) :
15+
return [None, op_module.gelu_grad_grad(dy, op.inputs[0], op.inputs[1])]

source/op/cuda/CMakeLists.txt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,14 @@ else ()
8080
endif()
8181

8282
set (SOURCE_FILES
83-
descrpt_se_a.cu descrpt_se_r.cu prod_force_se_a.cu prod_force_se_r.cu prod_virial_se_a.cu prod_virial_se_r.cu
83+
descrpt_se_a.cu descrpt_se_r.cu prod_force_se_a.cu prod_force_se_r.cu prod_virial_se_a.cu prod_virial_se_r.cu gelu.cu
8484
)
8585

86-
cuda_add_library(deepmd_op_cuda SHARED ${SOURCE_FILES})
86+
cuda_add_library(deepmd_op_cuda STATIC ${SOURCE_FILES})
8787

88-
install(TARGETS deepmd_op_cuda DESTINATION lib/)
88+
if (BUILD_CPP_IF)
89+
install(TARGETS deepmd_op_cuda DESTINATION lib/)
90+
endif (BUILD_CPP_IF)
91+
if (BUILD_PY_IF)
92+
install(TARGETS deepmd_op_cuda DESTINATION deepmd/)
93+
endif (BUILD_PY_IF)

source/op/cuda/descrpt_se_a.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License.
1818
#include <cub/block/block_radix_sort.cuh>
1919
#include <cuda_runtime.h>
2020

21-
#define MAGIC_NUMBER 256
21+
#define MAGIC_NUMBER 1024
2222

2323
#ifdef HIGH_PREC
2424
typedef double VALUETYPE;
@@ -326,7 +326,7 @@ void DescrptSeALauncher(const VALUETYPE* coord,
326326
i_idx
327327
);
328328
const int ITEMS_PER_THREAD = 4;
329-
const int BLOCK_THREADS = 64;
329+
const int BLOCK_THREADS = MAGIC_NUMBER / ITEMS_PER_THREAD;
330330
// BlockSortKernel<NeighborInfo, BLOCK_THREADS, ITEMS_PER_THREAD><<<g_grid_size, BLOCK_THREADS>>> (
331331
BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAGIC_NUMBER);
332332

source/op/cuda/gelu.cu

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include <cuda_runtime.h>
2+
#include <stdio.h>
3+
4+
#define SQRT_2_PI 0.7978845608028654
5+
6+
template <typename T>
7+
__global__ void gelu(const T * in, T * out, int const size) {
8+
int const idx = blockIdx.x * blockDim.x + threadIdx.x;
9+
if (idx >= size) {return;}
10+
11+
out[idx] = in[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx])));
12+
}
13+
14+
template <typename T>
15+
__global__ void gelu_grad(const T * dy, const T * in, T * out, int const size) {
16+
int const idx = blockIdx.x * blockDim.x + threadIdx.x;
17+
if (idx >= size) {return;}
18+
19+
// out[idx] = in[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx])));
20+
T const var1 = tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx]));
21+
out[idx] = dy[idx] * (0.5 * SQRT_2_PI * in[idx] * (1 - var1 * var1) * (0.134145 * in[idx] * in[idx] + 1) + 0.5 * var1 + 0.5);
22+
}
23+
24+
template <typename T>
25+
__global__ void gelu_grad_grad(const T * dy, const T * dy_, const T * in, T * out, int const size) {
26+
int const idx = blockIdx.x * blockDim.x + threadIdx.x;
27+
if (idx >= size) {return;}
28+
29+
// out[idx] = in[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx])));
30+
T const var1 = tanh(SQRT_2_PI * (in[idx] + 0.044715 * in[idx] * in[idx] *in[idx]));
31+
T const var2 = SQRT_2_PI * (1 - var1 * var1) * (0.134145 * in[idx] * in[idx] + 1);
32+
33+
out[idx] = dy[idx] * dy_[idx] * (0.134145 * SQRT_2_PI * in[idx] * in[idx] * (1 - var1 * var1) - SQRT_2_PI * in[idx] * var2 * (0.134145 * in[idx] * in[idx] + 1) * var1 + var2);
34+
}
35+
36+
37+
void GeluLauncher(const float * in, float * out, int const size) {
38+
int const THREAD_ITEMS = 1024;
39+
int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS;
40+
41+
gelu<<<BLOCK_NUMS, THREAD_ITEMS>>>(in, out, size);
42+
}
43+
44+
void GeluLauncher(const double * in, double * out, int const size) {
45+
int const THREAD_ITEMS = 1024;
46+
int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS;
47+
48+
gelu<<<BLOCK_NUMS, THREAD_ITEMS>>>(in, out, size);
49+
}
50+
51+
void GeluGradLauncher(const float * dy, const float * in, float * out, int const size) {
52+
int const THREAD_ITEMS = 1024;
53+
int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS;
54+
55+
gelu_grad<<<BLOCK_NUMS, THREAD_ITEMS>>>(dy, in, out, size);
56+
}
57+
58+
void GeluGradLauncher(const double * dy, const double * in, double * out, int const size) {
59+
int const THREAD_ITEMS = 1024;
60+
int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS;
61+
62+
gelu_grad<<<BLOCK_NUMS, THREAD_ITEMS>>>(dy, in, out, size);
63+
}
64+
65+
void GeluGradGradLauncher(const float * dy, const float * dy_, const float * in, float * out, int const size) {
66+
int const THREAD_ITEMS = 1024;
67+
int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS;
68+
69+
gelu_grad_grad<<<BLOCK_NUMS, THREAD_ITEMS>>>(dy, dy_, in, out, size);
70+
}
71+
72+
void GeluGradGradLauncher(const double * dy, const double * dy_, const double * in, double * out, int const size) {
73+
int const THREAD_ITEMS = 1024;
74+
int const BLOCK_NUMS = (size + THREAD_ITEMS - 1) / THREAD_ITEMS;
75+
76+
gelu_grad_grad<<<BLOCK_NUMS, THREAD_ITEMS>>>(dy, dy_, in, out, size);
77+
}

source/op/descrpt_se_a_gpu.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "tensorflow/core/framework/shape_inference.h"
88

99
using namespace tensorflow; // NOLINT(build/namespaces)
10-
#define MAGIC_NUMBER 256
1110

1211
#ifdef HIGH_PREC
1312
typedef double VALUETYPE ;
@@ -159,7 +158,8 @@ class DescrptSeAOp : public OpKernel {
159158

160159
OP_REQUIRES (context, (ntypes == int(sel_a.size())), errors::InvalidArgument ("number of types should match the length of sel array"));
161160
OP_REQUIRES (context, (ntypes == int(sel_r.size())), errors::InvalidArgument ("number of types should match the length of sel array"));
162-
161+
OP_REQUIRES (context, (nnei <= 1024), errors::InvalidArgument ("Assert failed, max neighbor size of atom(nnei) " + std::to_string(nnei) + " is larger than 1024!, which currently is not supported by deepmd-kit."));
162+
163163
// Create output tensors
164164
TensorShape descrpt_shape ;
165165
descrpt_shape.AddDim (nsamples);
@@ -201,7 +201,6 @@ class DescrptSeAOp : public OpKernel {
201201
cudaErrcheck(cudaMemcpy(&(array_longlong), 20 + mesh_tensor.flat<int>().data(), sizeof(unsigned long long *), cudaMemcpyDeviceToHost));
202202
cudaErrcheck(cudaMemcpy(&(array_double), 24 + mesh_tensor.flat<int>().data(), sizeof(compute_t *), cudaMemcpyDeviceToHost));
203203

204-
// cudaErrcheck(cudaMemcpy(jlist, host_jlist, sizeof(int) * nloc * MAGIC_NUMBER, cudaMemcpyHostToDevice));
205204
// Launch computation
206205
for (int II = 0; II < nsamples; II++) {
207206
DescrptSeALauncher(coord_tensor.matrix<VALUETYPE>().data() + II * (nall * 3), // related to the kk argument

0 commit comments

Comments
 (0)