|
1 | 1 | // SPDX-License-Identifier: LGPL-3.0-or-later |
2 | 2 | #pragma once |
3 | 3 |
|
4 | | -#include "DeepPot.h" |
| 4 | +#include <memory> |
| 5 | + |
| 6 | +#include "common.h" |
5 | 7 |
|
6 | 8 | namespace deepmd { |
| 9 | +/** |
| 10 | + * @brief Dipole charge modifier. (Base class) |
| 11 | + **/ |
| 12 | +class DipoleChargeModifierBase { |
| 13 | + public: |
| 14 | + /** |
| 15 | + * @brief Dipole charge modifier without initialization. |
| 16 | + **/ |
| 17 | + DipoleChargeModifierBase(){}; |
| 18 | + /** |
| 19 | + * @brief Dipole charge modifier without initialization. |
| 20 | + * @param[in] model The name of the frozen model file. |
| 21 | + * @param[in] gpu_rank The GPU rank. Default is 0. |
| 22 | + * @param[in] name_scope The name scope. |
| 23 | + **/ |
| 24 | + DipoleChargeModifierBase(const std::string& model, |
| 25 | + const int& gpu_rank = 0, |
| 26 | + const std::string& name_scope = ""); |
| 27 | + virtual ~DipoleChargeModifierBase(){}; |
| 28 | + /** |
| 29 | + * @brief Initialize the dipole charge modifier. |
| 30 | + * @param[in] model The name of the frozen model file. |
| 31 | + * @param[in] gpu_rank The GPU rank. Default is 0. |
| 32 | + * @param[in] name_scope The name scope. |
| 33 | + **/ |
| 34 | + virtual void init(const std::string& model, |
| 35 | + const int& gpu_rank = 0, |
| 36 | + const std::string& name_scope = "") = 0; |
| 37 | + /** |
| 38 | + * @brief Evaluate the force and virial correction by using this dipole charge |
| 39 | + *modifier. |
| 40 | + * @param[out] dfcorr_ The force correction on each atom. |
| 41 | + * @param[out] dvcorr_ The virial correction. |
| 42 | + * @param[in] dcoord_ The coordinates of atoms. The array should be of size |
| 43 | + *natoms x 3. |
| 44 | + * @param[in] datype_ The atom types. The list should contain natoms ints. |
| 45 | + * @param[in] dbox The cell of the region. The array should be of size 9. |
| 46 | + * @param[in] pairs The pairs of atoms. The list should contain npairs pairs |
| 47 | + *of ints. |
| 48 | + * @param[in] delef_ The electric field on each atom. The array should be of |
| 49 | + *size natoms x 3. |
| 50 | + * @param[in] nghost The number of ghost atoms. |
| 51 | + * @param[in] lmp_list The neighbor list. |
| 52 | + @{ |
| 53 | + **/ |
| 54 | + virtual void computew(std::vector<double>& dfcorr_, |
| 55 | + std::vector<double>& dvcorr_, |
| 56 | + const std::vector<double>& dcoord_, |
| 57 | + const std::vector<int>& datype_, |
| 58 | + const std::vector<double>& dbox, |
| 59 | + const std::vector<std::pair<int, int>>& pairs, |
| 60 | + const std::vector<double>& delef_, |
| 61 | + const int nghost, |
| 62 | + const InputNlist& lmp_list) = 0; |
| 63 | + virtual void computew(std::vector<float>& dfcorr_, |
| 64 | + std::vector<float>& dvcorr_, |
| 65 | + const std::vector<float>& dcoord_, |
| 66 | + const std::vector<int>& datype_, |
| 67 | + const std::vector<float>& dbox, |
| 68 | + const std::vector<std::pair<int, int>>& pairs, |
| 69 | + const std::vector<float>& delef_, |
| 70 | + const int nghost, |
| 71 | + const InputNlist& lmp_list) = 0; |
| 72 | + /** @} */ |
| 73 | + /** |
| 74 | + * @brief Get cutoff radius. |
| 75 | + * @return double cutoff radius. |
| 76 | + */ |
| 77 | + virtual double cutoff() const = 0; |
| 78 | + /** |
| 79 | + * @brief Get the number of atom types. |
| 80 | + * @return int number of atom types. |
| 81 | + */ |
| 82 | + virtual int numb_types() const = 0; |
| 83 | + /** |
| 84 | + * @brief Get the list of sel types. |
| 85 | + * @return The list of sel types. |
| 86 | + */ |
| 87 | + virtual std::vector<int> sel_types() const = 0; |
| 88 | +}; |
| 89 | + |
7 | 90 | /** |
8 | 91 | * @brief Dipole charge modifier. |
9 | 92 | **/ |
@@ -38,7 +121,6 @@ class DipoleChargeModifier { |
38 | 121 | **/ |
39 | 122 | void print_summary(const std::string& pre) const; |
40 | 123 |
|
41 | | - public: |
42 | 124 | /** |
43 | 125 | * @brief Evaluate the force and virial correction by using this dipole charge |
44 | 126 | *modifier. |
@@ -69,50 +151,20 @@ class DipoleChargeModifier { |
69 | 151 | * @brief Get cutoff radius. |
70 | 152 | * @return double cutoff radius. |
71 | 153 | */ |
72 | | - double cutoff() const { |
73 | | - assert(inited); |
74 | | - return rcut; |
75 | | - }; |
| 154 | + double cutoff() const; |
76 | 155 | /** |
77 | 156 | * @brief Get the number of atom types. |
78 | 157 | * @return int number of atom types. |
79 | 158 | */ |
80 | | - int numb_types() const { |
81 | | - assert(inited); |
82 | | - return ntypes; |
83 | | - }; |
| 159 | + int numb_types() const; |
84 | 160 | /** |
85 | 161 | * @brief Get the list of sel types. |
86 | 162 | * @return The list of sel types. |
87 | 163 | */ |
88 | | - std::vector<int> sel_types() const { |
89 | | - assert(inited); |
90 | | - return sel_type; |
91 | | - }; |
| 164 | + std::vector<int> sel_types() const; |
92 | 165 |
|
93 | 166 | private: |
94 | | - tensorflow::Session* session; |
95 | | - std::string name_scope, name_prefix; |
96 | | - int num_intra_nthreads, num_inter_nthreads; |
97 | | - tensorflow::GraphDef* graph_def; |
98 | 167 | bool inited; |
99 | | - double rcut; |
100 | | - int dtype; |
101 | | - double cell_size; |
102 | | - int ntypes; |
103 | | - std::string model_type; |
104 | | - std::vector<int> sel_type; |
105 | | - template <class VT> |
106 | | - VT get_scalar(const std::string& name) const; |
107 | | - template <class VT> |
108 | | - void get_vector(std::vector<VT>& vec, const std::string& name) const; |
109 | | - template <typename MODELTYPE, typename VALUETYPE> |
110 | | - void run_model(std::vector<VALUETYPE>& dforce, |
111 | | - std::vector<VALUETYPE>& dvirial, |
112 | | - tensorflow::Session* session, |
113 | | - const std::vector<std::pair<std::string, tensorflow::Tensor>>& |
114 | | - input_tensors, |
115 | | - const AtomMap& atommap, |
116 | | - const int nghost); |
| 168 | + std::shared_ptr<deepmd::DipoleChargeModifierBase> dcm; |
117 | 169 | }; |
118 | 170 | } // namespace deepmd |
0 commit comments