55#include " tensorflow/core/framework/op.h"
66#include " tensorflow/core/framework/op_kernel.h"
77#include " tensorflow/core/framework/shape_inference.h"
8-
8+ # include " NNPAtomMap.h "
99#include < vector>
1010#include " version.h"
1111
12+ typedef double compute_t ;
1213using namespace tensorflow ;
1314using namespace std ;
1415
@@ -53,6 +54,7 @@ class NNPInter
5354{
5455public:
5556 NNPInter () ;
57+ ~NNPInter () ;
5658 NNPInter (const string & model, const int & gpu_rank = 0 );
5759 void init (const string & model, const int & gpu_rank = 0 );
5860 void print_summary (const string &pre ) const ;
@@ -74,6 +76,7 @@ class NNPInter
7476 const vector<VALUETYPE> & box,
7577 const int nghost,
7678 const LammpsNeighborList & lmp_list,
79+ const int & ago,
7780 const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
7881 const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
7982 void compute (ENERGYTYPE & ener,
@@ -96,6 +99,7 @@ class NNPInter
9699 const vector<VALUETYPE> & box,
97100 const int nghost,
98101 const LammpsNeighborList & lmp_list,
102+ const int & ago,
99103 const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
100104 const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
101105 VALUETYPE cutoff () const {assert (inited); return rcut;};
@@ -118,12 +122,30 @@ class NNPInter
118122 void validate_fparam_aparam (const int & nloc,
119123 const vector<VALUETYPE> &fparam,
120124 const vector<VALUETYPE> &aparam)const ;
125+
126+ // copy neighbor list info from host
127+ bool init_nbor;
128+ std::vector<int > sec_a;
129+ compute_t *array_double;
130+ InternalNeighborList nlist;
131+ NNPAtomMap<VALUETYPE> nnpmap;
132+ unsigned long long *array_longlong;
133+ int *ilist, *jrange, *jlist, *array_int;
134+ int ilist_size, jrange_size, jlist_size;
135+ int arr_int_size, arr_ll_size, arr_dou_size;
136+
137+ // function used for neighbor list copy
138+ vector<int > get_sel_a () const ;
139+ #ifdef USE_CUDA_TOOLKIT
140+ void update_nbor (const InternalNeighborList & nlist, const int nloc);
141+ #endif
121142};
122143
123144class NNPInterModelDevi
124145{
125146public:
126147 NNPInterModelDevi () ;
148+ ~NNPInterModelDevi () ;
127149 NNPInterModelDevi (const vector<string> & models, const int & gpu_rank = 0 );
128150 void init (const vector<string> & models, const int & gpu_rank = 0 );
129151public:
@@ -144,6 +166,7 @@ class NNPInterModelDevi
144166 const vector<VALUETYPE> & box,
145167 const int nghost,
146168 const LammpsNeighborList & lmp_list,
169+ const int & ago,
147170 const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
148171 const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
149172 void compute (vector<ENERGYTYPE> & all_ener,
@@ -156,6 +179,7 @@ class NNPInterModelDevi
156179 const vector<VALUETYPE> & box,
157180 const int nghost,
158181 const LammpsNeighborList & lmp_list,
182+ const int & ago,
159183 const vector<VALUETYPE> & fparam = vector<VALUETYPE>(),
160184 const vector<VALUETYPE> & aparam = vector<VALUETYPE>());
161185 VALUETYPE cutoff () const {assert (inited); return rcut;};
@@ -193,6 +217,25 @@ class NNPInterModelDevi
193217 void validate_fparam_aparam (const int & nloc,
194218 const vector<VALUETYPE> &fparam,
195219 const vector<VALUETYPE> &aparam)const ;
220+
221+ // copy neighbor list info from host
222+ bool init_nbor;
223+ compute_t *array_double;
224+ vector<vector<int > > sec;
225+ InternalNeighborList nlist;
226+ NNPAtomMap<VALUETYPE> nnpmap;
227+ unsigned long long *array_longlong;
228+ int max_sec_size = 0 , max_sec_back = 0 ;
229+ int *ilist, *jrange, *jlist, *array_int;
230+ int ilist_size, jrange_size, jlist_size, arr_int_size, arr_ll_size, arr_dou_size;
231+
232+ // function used for nborlist copy
233+ void get_max_sec ();
234+ vector<vector<int > > get_sel () const ;
235+ void cum_sum (const std::vector<std::vector<int32> > n_sel);
236+ #ifdef USE_CUDA_TOOLKIT
237+ void update_nbor (const InternalNeighborList & nlist, const int nloc);
238+ #endif
196239};
197240
198241
0 commit comments