Skip to content

Commit 828df66

Browse files
authored
cc: refactor DeepPotModelDevi, making it framework-independent (#3134)
Refactor `DeepPotModelDevi` as a step of #3122. Now, it is just a wrapper of multiple `DeepPot` classes. Models can have different behaviors inside different `DeepPot`. One may argue that the new class needs to prepare the input multiple times. However, it's not expensive only to copy the memory. Also, during the simulations, usually we run it every 100 steps.
1 parent 04f07ef commit 828df66

File tree

2 files changed

+22
-331
lines changed

2 files changed

+22
-331
lines changed

source/api_cc/include/DeepPot.h

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -480,39 +480,39 @@ class DeepPotModelDevi {
480480
**/
481481
double cutoff() const {
482482
assert(inited);
483-
return rcut;
483+
return dps[0].cutoff();
484484
};
485485
/**
486486
* @brief Get the number of types.
487487
* @return The number of types.
488488
**/
489489
int numb_types() const {
490490
assert(inited);
491-
return ntypes;
491+
return dps[0].numb_types();
492492
};
493493
/**
494494
* @brief Get the number of types with spin.
495495
* @return The number of types with spin.
496496
**/
497497
int numb_types_spin() const {
498498
assert(inited);
499-
return ntypes_spin;
499+
return dps[0].numb_types_spin();
500500
};
501501
/**
502502
* @brief Get the dimension of the frame parameter.
503503
* @return The dimension of the frame parameter.
504504
**/
505505
int dim_fparam() const {
506506
assert(inited);
507-
return dfparam;
507+
return dps[0].dim_fparam();
508508
};
509509
/**
510510
* @brief Get the dimension of the atomic parameter.
511511
* @return The dimension of the atomic parameter.
512512
**/
513513
int dim_aparam() const {
514514
assert(inited);
515-
return daparam;
515+
return dps[0].dim_aparam();
516516
};
517517
/**
518518
* @brief Compute the average energy.
@@ -590,39 +590,12 @@ class DeepPotModelDevi {
590590
**/
591591
bool is_aparam_nall() const {
592592
assert(inited);
593-
return aparam_nall;
593+
return dps[0].is_aparam_nall();
594594
};
595595

596596
private:
597597
unsigned numb_models;
598-
std::vector<tensorflow::Session*> sessions;
599-
int num_intra_nthreads, num_inter_nthreads;
600-
std::vector<tensorflow::GraphDef*> graph_defs;
598+
std::vector<deepmd::DeepPot> dps;
601599
bool inited;
602-
template <class VT>
603-
VT get_scalar(const std::string name) const;
604-
// VALUETYPE get_rcut () const;
605-
// int get_ntypes () const;
606-
double rcut;
607-
double cell_size;
608-
int dtype;
609-
std::string model_type;
610-
std::string model_version;
611-
int ntypes;
612-
int ntypes_spin;
613-
int dfparam;
614-
int daparam;
615-
bool aparam_nall;
616-
template <typename VALUETYPE>
617-
void validate_fparam_aparam(const int& nloc,
618-
const std::vector<VALUETYPE>& fparam,
619-
const std::vector<VALUETYPE>& aparam) const;
620-
621-
// copy neighbor list info from host
622-
bool init_nbor;
623-
std::vector<std::vector<int> > sec;
624-
deepmd::AtomMap atommap;
625-
NeighborListData nlist_data;
626-
InputNlist nlist;
627600
};
628601
} // namespace deepmd

0 commit comments

Comments
 (0)