Skip to content

Commit 3d59019

Browse files
LuLu
authored andcommitted
Add GPU support for tensorflow operations
1 parent f9581c5 commit 3d59019

21 files changed

+2920
-146
lines changed

source/lib/include/NNPAtomMap.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ template <typename VALUETYPE>
88
class NNPAtomMap
99
{
1010
public:
11+
NNPAtomMap();
1112
NNPAtomMap(const vector<int >::const_iterator in_begin,
1213
const vector<int >::const_iterator in_end);
1314
void forward (typename vector<VALUETYPE >::iterator out,

source/lib/include/NNPInter.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
#include "tensorflow/core/framework/op.h"
66
#include "tensorflow/core/framework/op_kernel.h"
77
#include "tensorflow/core/framework/shape_inference.h"
8-
8+
#include "NNPAtomMap.h"
99
#include <vector>
1010
#include "version.h"
1111

12+
typedef double compute_t;
1213
using namespace tensorflow;
1314
using namespace std;
1415

@@ -53,6 +54,7 @@ class NNPInter
5354
{
5455
public:
5556
NNPInter () ;
57+
~NNPInter() ;
5658
NNPInter (const string & model, const int & gpu_rank = 0);
5759
void init (const string & model, const int & gpu_rank = 0);
5860
void print_summary(const string &pre) const;
@@ -74,6 +76,7 @@ class NNPInter
7476
const vector<VALUETYPE> & box,
7577
const int nghost,
7678
const LammpsNeighborList & lmp_list,
79+
const int & ago,
7780
const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
7881
const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
7982
void compute (ENERGYTYPE & ener,
@@ -96,6 +99,7 @@ class NNPInter
9699
const vector<VALUETYPE> & box,
97100
const int nghost,
98101
const LammpsNeighborList & lmp_list,
102+
const int & ago,
99103
const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
100104
const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
101105
VALUETYPE cutoff () const {assert(inited); return rcut;};
@@ -118,12 +122,27 @@ class NNPInter
118122
void validate_fparam_aparam(const int & nloc,
119123
const vector<VALUETYPE> &fparam,
120124
const vector<VALUETYPE> &aparam)const ;
125+
126+
// copy neighbor list info from host
127+
bool init_nbor;
128+
std::vector<int> sec_a;
129+
compute_t *array_double;
130+
NNPAtomMap<VALUETYPE> nnpmap;
131+
unsigned long long *array_longlong;
132+
int *ilist, *jrange, *jlist, *array_int;
133+
int ilist_size, jrange_size, jlist_size;
134+
int arr_int_size, arr_ll_size, arr_dou_size;
135+
136+
// function used for neighbor list copy
137+
vector<int> get_sel_a() const;
138+
void update_nbor(const InternalNeighborList & nlist, const int nloc);
121139
};
122140

123141
class NNPInterModelDevi
124142
{
125143
public:
126144
NNPInterModelDevi () ;
145+
~NNPInterModelDevi() ;
127146
NNPInterModelDevi (const vector<string> & models, const int & gpu_rank = 0);
128147
void init (const vector<string> & models, const int & gpu_rank = 0);
129148
public:
@@ -144,6 +163,7 @@ class NNPInterModelDevi
144163
const vector<VALUETYPE> & box,
145164
const int nghost,
146165
const LammpsNeighborList & lmp_list,
166+
const int & ago,
147167
const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
148168
const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
149169
void compute (vector<ENERGYTYPE> & all_ener,
@@ -156,6 +176,7 @@ class NNPInterModelDevi
156176
const vector<VALUETYPE> & box,
157177
const int nghost,
158178
const LammpsNeighborList & lmp_list,
179+
const int & ago,
159180
const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
160181
const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
161182
VALUETYPE cutoff () const {assert(inited); return rcut;};
@@ -193,6 +214,22 @@ class NNPInterModelDevi
193214
void validate_fparam_aparam(const int & nloc,
194215
const vector<VALUETYPE> &fparam,
195216
const vector<VALUETYPE> &aparam)const ;
217+
218+
// copy neighbor list info from host
219+
bool init_nbor;
220+
vector<vector<int> > sec;
221+
compute_t *array_double;
222+
NNPAtomMap<VALUETYPE> nnpmap;
223+
unsigned long long *array_longlong;
224+
int max_sec_size = 0, max_sec_back = 0;
225+
int *ilist, *jrange, *jlist, *array_int;
226+
int ilist_size, jrange_size, jlist_size, arr_int_size, arr_ll_size, arr_dou_size;
227+
228+
// function used for nborlist copy
229+
void get_max_sec();
230+
vector<vector<int> > get_sel() const;
231+
void cum_sum(const std::vector<std::vector<int32> > n_sel);
232+
void update_nbor(const InternalNeighborList & nlist, const int nloc);
196233
};
197234

198235

source/lib/src/NNPAtomMap.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#include <algorithm>
44
#include <cassert>
55

6+
template <typename VALUETYPE>
7+
NNPAtomMap<VALUETYPE>::
8+
NNPAtomMap() {}
9+
610
template <typename VALUETYPE>
711
NNPAtomMap<VALUETYPE>::
812
NNPAtomMap(const vector<int >::const_iterator in_begin,

0 commit comments

Comments
 (0)