55#include " tensorflow/core/framework/op.h"
66#include " tensorflow/core/framework/op_kernel.h"
77#include " tensorflow/core/framework/shape_inference.h"
8-
8+ # include " NNPAtomMap.h "
99#include < vector>
1010#include " version.h"
1111
12+ typedef double compute_t ;
1213using namespace tensorflow ;
1314using namespace std ;
1415
@@ -53,6 +54,7 @@ class NNPInter
5354{
5455public:
5556 NNPInter () ;
57+ ~NNPInter () ;
5658 NNPInter (const string & model, const int & gpu_rank = 0 );
5759 void init (const string & model, const int & gpu_rank = 0 );
5860 void print_summary (const string &pre ) const ;
@@ -74,6 +76,7 @@ class NNPInter
7476 const vector<VALUETYPE> & box,
7577 const int nghost,
7678 const LammpsNeighborList & lmp_list,
79+ const int & ago,
7780 const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
7881 const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
7982 void compute (ENERGYTYPE & ener,
@@ -96,6 +99,7 @@ class NNPInter
9699 const vector<VALUETYPE> & box,
97100 const int nghost,
98101 const LammpsNeighborList & lmp_list,
102+ const int & ago,
99103 const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
100104 const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
101105 VALUETYPE cutoff () const {assert (inited); return rcut;};
@@ -118,12 +122,27 @@ class NNPInter
118122 void validate_fparam_aparam (const int & nloc,
119123 const vector<VALUETYPE> &fparam,
120124 const vector<VALUETYPE> &aparam)const ;
125+
126+ // copy neighbor list info from host
127+ bool init_nbor;
128+ std::vector<int > sec_a;
129+ compute_t *array_double;
130+ NNPAtomMap<VALUETYPE> nnpmap;
131+ unsigned long long *array_longlong;
132+ int *ilist, *jrange, *jlist, *array_int;
133+ int ilist_size, jrange_size, jlist_size;
134+ int arr_int_size, arr_ll_size, arr_dou_size;
135+
136+ // function used for neighbor list copy
137+ vector<int > get_sel_a () const ;
138+ void update_nbor (const InternalNeighborList & nlist, const int nloc);
121139};
122140
123141class NNPInterModelDevi
124142{
125143public:
126144 NNPInterModelDevi () ;
145+ ~NNPInterModelDevi () ;
127146 NNPInterModelDevi (const vector<string> & models, const int & gpu_rank = 0 );
128147 void init (const vector<string> & models, const int & gpu_rank = 0 );
129148public:
@@ -144,6 +163,7 @@ class NNPInterModelDevi
144163 const vector<VALUETYPE> & box,
145164 const int nghost,
146165 const LammpsNeighborList & lmp_list,
166+ const int & ago,
147167 const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
148168 const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
149169 void compute (vector<ENERGYTYPE> & all_ener,
@@ -156,6 +176,7 @@ class NNPInterModelDevi
156176 const vector<VALUETYPE> & box,
157177 const int nghost,
158178 const LammpsNeighborList & lmp_list,
179+ const int & ago,
159180 const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
160181 const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
161182 VALUETYPE cutoff () const {assert (inited); return rcut;};
@@ -193,6 +214,22 @@ class NNPInterModelDevi
193214 void validate_fparam_aparam (const int & nloc,
194215 const vector<VALUETYPE> &fparam,
195216 const vector<VALUETYPE> &aparam)const ;
217+
218+ // copy neighbor list info from host
219+ bool init_nbor;
220+ vector<vector<int > > sec;
221+ compute_t *array_double;
222+ NNPAtomMap<VALUETYPE> nnpmap;
223+ unsigned long long *array_longlong;
224+ int max_sec_size = 0 , max_sec_back = 0 ;
225+ int *ilist, *jrange, *jlist, *array_int;
226+ int ilist_size, jrange_size, jlist_size, arr_int_size, arr_ll_size, arr_dou_size;
227+
228+ // function used for nborlist copy
229+ void get_max_sec ();
230+ vector<vector<int > > get_sel () const ;
231+ void cum_sum (const std::vector<std::vector<int32> > n_sel);
232+ void update_nbor (const InternalNeighborList & nlist, const int nloc);
196233};
197234
198235
0 commit comments