Skip to content

Commit 1045675

Browse files
authored
Merge pull request #120 from denghuilu/devel-up
Add GPU support for tensorflow operations
2 parents 6b5cdc6 + 55a29ba commit 1045675

23 files changed

+2990
-134
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "source/op/cuda/cub"]
2+
path = source/op/cuda/cub
3+
url = git://github.com/NVlabs/cub.git

source/lib/include/NNPAtomMap.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ template <typename VALUETYPE>
88
class NNPAtomMap
99
{
1010
public:
11+
NNPAtomMap();
1112
NNPAtomMap(const vector<int >::const_iterator in_begin,
1213
const vector<int >::const_iterator in_end);
1314
void forward (typename vector<VALUETYPE >::iterator out,

source/lib/include/NNPInter.h

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
#include "tensorflow/core/framework/op.h"
66
#include "tensorflow/core/framework/op_kernel.h"
77
#include "tensorflow/core/framework/shape_inference.h"
8-
8+
#include "NNPAtomMap.h"
99
#include <vector>
1010
#include "version.h"
1111

12+
typedef double compute_t;
1213
using namespace tensorflow;
1314
using namespace std;
1415

@@ -53,6 +54,7 @@ class NNPInter
5354
{
5455
public:
5556
NNPInter () ;
57+
~NNPInter() ;
5658
NNPInter (const string & model, const int & gpu_rank = 0);
5759
void init (const string & model, const int & gpu_rank = 0);
5860
void print_summary(const string &pre) const;
@@ -74,6 +76,7 @@ class NNPInter
7476
const vector<VALUETYPE> & box,
7577
const int nghost,
7678
const LammpsNeighborList & lmp_list,
79+
const int & ago,
7780
const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
7881
const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
7982
void compute (ENERGYTYPE & ener,
@@ -96,6 +99,7 @@ class NNPInter
9699
const vector<VALUETYPE> & box,
97100
const int nghost,
98101
const LammpsNeighborList & lmp_list,
102+
const int & ago,
99103
const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
100104
const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
101105
VALUETYPE cutoff () const {assert(inited); return rcut;};
@@ -118,12 +122,30 @@ class NNPInter
118122
void validate_fparam_aparam(const int & nloc,
119123
const vector<VALUETYPE> &fparam,
120124
const vector<VALUETYPE> &aparam)const ;
125+
126+
// copy neighbor list info from host
127+
bool init_nbor;
128+
std::vector<int> sec_a;
129+
compute_t *array_double;
130+
InternalNeighborList nlist;
131+
NNPAtomMap<VALUETYPE> nnpmap;
132+
unsigned long long *array_longlong;
133+
int *ilist, *jrange, *jlist, *array_int;
134+
int ilist_size, jrange_size, jlist_size;
135+
int arr_int_size, arr_ll_size, arr_dou_size;
136+
137+
// function used for neighbor list copy
138+
vector<int> get_sel_a() const;
139+
#ifdef USE_CUDA_TOOLKIT
140+
void update_nbor(const InternalNeighborList & nlist, const int nloc);
141+
#endif
121142
};
122143

123144
class NNPInterModelDevi
124145
{
125146
public:
126147
NNPInterModelDevi () ;
148+
~NNPInterModelDevi() ;
127149
NNPInterModelDevi (const vector<string> & models, const int & gpu_rank = 0);
128150
void init (const vector<string> & models, const int & gpu_rank = 0);
129151
public:
@@ -144,6 +166,7 @@ class NNPInterModelDevi
144166
const vector<VALUETYPE> & box,
145167
const int nghost,
146168
const LammpsNeighborList & lmp_list,
169+
const int & ago,
147170
const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
148171
const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
149172
void compute (vector<ENERGYTYPE> & all_ener,
@@ -156,6 +179,7 @@ class NNPInterModelDevi
156179
const vector<VALUETYPE> & box,
157180
const int nghost,
158181
const LammpsNeighborList & lmp_list,
182+
const int & ago,
159183
const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
160184
const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
161185
VALUETYPE cutoff () const {assert(inited); return rcut;};
@@ -193,6 +217,25 @@ class NNPInterModelDevi
193217
void validate_fparam_aparam(const int & nloc,
194218
const vector<VALUETYPE> &fparam,
195219
const vector<VALUETYPE> &aparam)const ;
220+
221+
// copy neighbor list info from host
222+
bool init_nbor;
223+
compute_t *array_double;
224+
vector<vector<int> > sec;
225+
InternalNeighborList nlist;
226+
NNPAtomMap<VALUETYPE> nnpmap;
227+
unsigned long long *array_longlong;
228+
int max_sec_size = 0, max_sec_back = 0;
229+
int *ilist, *jrange, *jlist, *array_int;
230+
int ilist_size, jrange_size, jlist_size, arr_int_size, arr_ll_size, arr_dou_size;
231+
232+
// function used for nborlist copy
233+
void get_max_sec();
234+
vector<vector<int> > get_sel() const;
235+
void cum_sum(const std::vector<std::vector<int32> > n_sel);
236+
#ifdef USE_CUDA_TOOLKIT
237+
void update_nbor(const InternalNeighborList & nlist, const int nloc);
238+
#endif
196239
};
197240

198241

source/lib/src/NNPAtomMap.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#include <algorithm>
44
#include <cassert>
55

6+
template <typename VALUETYPE>
7+
NNPAtomMap<VALUETYPE>::
8+
NNPAtomMap() {}
9+
610
template <typename VALUETYPE>
711
NNPAtomMap<VALUETYPE>::
812
NNPAtomMap(const vector<int >::const_iterator in_begin,

0 commit comments

Comments
 (0)