|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "NNPInter.h" |
| 4 | + |
| 5 | +class DeepTensor |
| 6 | +{ |
| 7 | +public: |
| 8 | + DeepTensor(); |
| 9 | + DeepTensor(const string & model, |
| 10 | + const int & gpu_rank = 0, |
| 11 | + const string &name_scope = ""); |
| 12 | + void init (const string & model, |
| 13 | + const int & gpu_rank = 0, |
| 14 | + const string &name_scope = ""); |
| 15 | + void print_summary(const string &pre) const; |
| 16 | +public: |
| 17 | + void compute (vector<VALUETYPE> & value, |
| 18 | + const vector<VALUETYPE> & coord, |
| 19 | + const vector<int> & atype, |
| 20 | + const vector<VALUETYPE> & box, |
| 21 | + const int nghost = 0); |
| 22 | + void compute (vector<VALUETYPE> & value, |
| 23 | + const vector<VALUETYPE> & coord, |
| 24 | + const vector<int> & atype, |
| 25 | + const vector<VALUETYPE> & box, |
| 26 | + const int nghost, |
| 27 | + const LammpsNeighborList & lmp_list); |
| 28 | + VALUETYPE cutoff () const {assert(inited); return rcut;}; |
| 29 | + int numb_types () const {assert(inited); return ntypes;}; |
| 30 | + int output_dim () const {assert(inited); return odim;}; |
| 31 | + const vector<int> & sel_types () const {assert(inited); return sel_type;}; |
| 32 | +private: |
| 33 | + Session* session; |
| 34 | + string name_scope; |
| 35 | + int num_intra_nthreads, num_inter_nthreads; |
| 36 | + GraphDef graph_def; |
| 37 | + bool inited; |
| 38 | + VALUETYPE rcut; |
| 39 | + VALUETYPE cell_size; |
| 40 | + int ntypes; |
| 41 | + string model_type; |
| 42 | + int odim; |
| 43 | + vector<int> sel_type; |
| 44 | + template<class VT> VT get_scalar(const string & name) const; |
| 45 | + template<class VT> void get_vector (vector<VT> & vec, const string & name) const; |
| 46 | + void run_model (vector<VALUETYPE> & d_tensor_, |
| 47 | + Session * session, |
| 48 | + const std::vector<std::pair<string, Tensor>> & input_tensors, |
| 49 | + const NNPAtomMap<VALUETYPE> & nnpmap, |
| 50 | + const int nghost = 0); |
| 51 | + void compute_inner (vector<VALUETYPE> & value, |
| 52 | + const vector<VALUETYPE> & coord, |
| 53 | + const vector<int> & atype, |
| 54 | + const vector<VALUETYPE> & box, |
| 55 | + const int nghost = 0); |
| 56 | + void compute_inner (vector<VALUETYPE> & value, |
| 57 | + const vector<VALUETYPE> & coord, |
| 58 | + const vector<int> & atype, |
| 59 | + const vector<VALUETYPE> & box, |
| 60 | + const int nghost, |
| 61 | + const InternalNeighborList&lmp_list); |
| 62 | +}; |
| 63 | + |
0 commit comments