Skip to content

Commit 1efc7f8

Browse files
cc: refactor DataModifier for multiple-backend framework (#3148)
See #3119 --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 850575a commit 1efc7f8

File tree

4 files changed

+611
-319
lines changed

4 files changed

+611
-319
lines changed

source/api_cc/include/DataModifier.h

Lines changed: 88 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,92 @@
11
// SPDX-License-Identifier: LGPL-3.0-or-later
22
#pragma once
33

4-
#include "DeepPot.h"
4+
#include <memory>
5+
6+
#include "common.h"
57

68
namespace deepmd {
9+
/**
10+
* @brief Dipole charge modifier. (Base class)
11+
**/
12+
class DipoleChargeModifierBase {
13+
public:
14+
/**
15+
* @brief Dipole charge modifier without initialization.
16+
**/
17+
DipoleChargeModifierBase(){};
18+
/**
19+
* @brief Dipole charge modifier without initialization.
20+
* @param[in] model The name of the frozen model file.
21+
* @param[in] gpu_rank The GPU rank. Default is 0.
22+
* @param[in] name_scope The name scope.
23+
**/
24+
DipoleChargeModifierBase(const std::string& model,
25+
const int& gpu_rank = 0,
26+
const std::string& name_scope = "");
27+
virtual ~DipoleChargeModifierBase(){};
28+
/**
29+
* @brief Initialize the dipole charge modifier.
30+
* @param[in] model The name of the frozen model file.
31+
* @param[in] gpu_rank The GPU rank. Default is 0.
32+
* @param[in] name_scope The name scope.
33+
**/
34+
virtual void init(const std::string& model,
35+
const int& gpu_rank = 0,
36+
const std::string& name_scope = "") = 0;
37+
/**
38+
* @brief Evaluate the force and virial correction by using this dipole charge
39+
*modifier.
40+
* @param[out] dfcorr_ The force correction on each atom.
41+
* @param[out] dvcorr_ The virial correction.
42+
* @param[in] dcoord_ The coordinates of atoms. The array should be of size
43+
*natoms x 3.
44+
* @param[in] datype_ The atom types. The list should contain natoms ints.
45+
* @param[in] dbox The cell of the region. The array should be of size 9.
46+
* @param[in] pairs The pairs of atoms. The list should contain npairs pairs
47+
*of ints.
48+
* @param[in] delef_ The electric field on each atom. The array should be of
49+
*size natoms x 3.
50+
* @param[in] nghost The number of ghost atoms.
51+
* @param[in] lmp_list The neighbor list.
52+
@{
53+
**/
54+
virtual void computew(std::vector<double>& dfcorr_,
55+
std::vector<double>& dvcorr_,
56+
const std::vector<double>& dcoord_,
57+
const std::vector<int>& datype_,
58+
const std::vector<double>& dbox,
59+
const std::vector<std::pair<int, int>>& pairs,
60+
const std::vector<double>& delef_,
61+
const int nghost,
62+
const InputNlist& lmp_list) = 0;
63+
virtual void computew(std::vector<float>& dfcorr_,
64+
std::vector<float>& dvcorr_,
65+
const std::vector<float>& dcoord_,
66+
const std::vector<int>& datype_,
67+
const std::vector<float>& dbox,
68+
const std::vector<std::pair<int, int>>& pairs,
69+
const std::vector<float>& delef_,
70+
const int nghost,
71+
const InputNlist& lmp_list) = 0;
72+
/** @} */
73+
/**
74+
* @brief Get cutoff radius.
75+
* @return double cutoff radius.
76+
*/
77+
virtual double cutoff() const = 0;
78+
/**
79+
* @brief Get the number of atom types.
80+
* @return int number of atom types.
81+
*/
82+
virtual int numb_types() const = 0;
83+
/**
84+
* @brief Get the list of sel types.
85+
* @return The list of sel types.
86+
*/
87+
virtual std::vector<int> sel_types() const = 0;
88+
};
89+
790
/**
891
* @brief Dipole charge modifier.
992
**/
@@ -38,7 +121,6 @@ class DipoleChargeModifier {
38121
**/
39122
void print_summary(const std::string& pre) const;
40123

41-
public:
42124
/**
43125
* @brief Evaluate the force and virial correction by using this dipole charge
44126
*modifier.
@@ -69,50 +151,20 @@ class DipoleChargeModifier {
69151
* @brief Get cutoff radius.
70152
* @return double cutoff radius.
71153
*/
72-
double cutoff() const {
73-
assert(inited);
74-
return rcut;
75-
};
154+
double cutoff() const;
76155
/**
77156
* @brief Get the number of atom types.
78157
* @return int number of atom types.
79158
*/
80-
int numb_types() const {
81-
assert(inited);
82-
return ntypes;
83-
};
159+
int numb_types() const;
84160
/**
85161
* @brief Get the list of sel types.
86162
* @return The list of sel types.
87163
*/
88-
std::vector<int> sel_types() const {
89-
assert(inited);
90-
return sel_type;
91-
};
164+
std::vector<int> sel_types() const;
92165

93166
private:
94-
tensorflow::Session* session;
95-
std::string name_scope, name_prefix;
96-
int num_intra_nthreads, num_inter_nthreads;
97-
tensorflow::GraphDef* graph_def;
98167
bool inited;
99-
double rcut;
100-
int dtype;
101-
double cell_size;
102-
int ntypes;
103-
std::string model_type;
104-
std::vector<int> sel_type;
105-
template <class VT>
106-
VT get_scalar(const std::string& name) const;
107-
template <class VT>
108-
void get_vector(std::vector<VT>& vec, const std::string& name) const;
109-
template <typename MODELTYPE, typename VALUETYPE>
110-
void run_model(std::vector<VALUETYPE>& dforce,
111-
std::vector<VALUETYPE>& dvirial,
112-
tensorflow::Session* session,
113-
const std::vector<std::pair<std::string, tensorflow::Tensor>>&
114-
input_tensors,
115-
const AtomMap& atommap,
116-
const int nghost);
168+
std::shared_ptr<deepmd::DipoleChargeModifierBase> dcm;
117169
};
118170
} // namespace deepmd
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// SPDX-License-Identifier: LGPL-3.0-or-later
2+
#pragma once
3+
4+
#include "DataModifier.h"
5+
#include "common.h"
6+
7+
namespace deepmd {
8+
/**
9+
* @brief Dipole charge modifier.
10+
**/
11+
class DipoleChargeModifierTF : public DipoleChargeModifierBase {
12+
public:
13+
/**
14+
* @brief Dipole charge modifier without initialization.
15+
**/
16+
DipoleChargeModifierTF();
17+
/**
18+
* @brief Dipole charge modifier without initialization.
19+
* @param[in] model The name of the frozen model file.
20+
* @param[in] gpu_rank The GPU rank. Default is 0.
21+
* @param[in] name_scope The name scope.
22+
**/
23+
DipoleChargeModifierTF(const std::string& model,
24+
const int& gpu_rank = 0,
25+
const std::string& name_scope = "");
26+
~DipoleChargeModifierTF();
27+
/**
28+
* @brief Initialize the dipole charge modifier.
29+
* @param[in] model The name of the frozen model file.
30+
* @param[in] gpu_rank The GPU rank. Default is 0.
31+
* @param[in] name_scope The name scope.
32+
**/
33+
void init(const std::string& model,
34+
const int& gpu_rank = 0,
35+
const std::string& name_scope = "");
36+
37+
public:
38+
/**
39+
* @brief Evaluate the force and virial correction by using this dipole charge
40+
*modifier.
41+
* @param[out] dfcorr_ The force correction on each atom.
42+
* @param[out] dvcorr_ The virial correction.
43+
* @param[in] dcoord_ The coordinates of atoms. The array should be of size
44+
*natoms x 3.
45+
* @param[in] datype_ The atom types. The list should contain natoms ints.
46+
* @param[in] dbox The cell of the region. The array should be of size 9.
47+
* @param[in] pairs The pairs of atoms. The list should contain npairs pairs
48+
*of ints.
49+
* @param[in] delef_ The electric field on each atom. The array should be of
50+
*size natoms x 3.
51+
* @param[in] nghost The number of ghost atoms.
52+
* @param[in] lmp_list The neighbor list.
53+
**/
54+
template <typename VALUETYPE>
55+
void compute(std::vector<VALUETYPE>& dfcorr_,
56+
std::vector<VALUETYPE>& dvcorr_,
57+
const std::vector<VALUETYPE>& dcoord_,
58+
const std::vector<int>& datype_,
59+
const std::vector<VALUETYPE>& dbox,
60+
const std::vector<std::pair<int, int>>& pairs,
61+
const std::vector<VALUETYPE>& delef_,
62+
const int nghost,
63+
const InputNlist& lmp_list);
64+
/**
65+
* @brief Get cutoff radius.
66+
* @return double cutoff radius.
67+
*/
68+
double cutoff() const {
69+
assert(inited);
70+
return rcut;
71+
};
72+
/**
73+
* @brief Get the number of atom types.
74+
* @return int number of atom types.
75+
*/
76+
int numb_types() const {
77+
assert(inited);
78+
return ntypes;
79+
};
80+
/**
81+
* @brief Get the list of sel types.
82+
* @return The list of sel types.
83+
*/
84+
std::vector<int> sel_types() const {
85+
assert(inited);
86+
return sel_type;
87+
};
88+
void computew(std::vector<double>& dfcorr_,
89+
std::vector<double>& dvcorr_,
90+
const std::vector<double>& dcoord_,
91+
const std::vector<int>& datype_,
92+
const std::vector<double>& dbox,
93+
const std::vector<std::pair<int, int>>& pairs,
94+
const std::vector<double>& delef_,
95+
const int nghost,
96+
const InputNlist& lmp_list);
97+
void computew(std::vector<float>& dfcorr_,
98+
std::vector<float>& dvcorr_,
99+
const std::vector<float>& dcoord_,
100+
const std::vector<int>& datype_,
101+
const std::vector<float>& dbox,
102+
const std::vector<std::pair<int, int>>& pairs,
103+
const std::vector<float>& delef_,
104+
const int nghost,
105+
const InputNlist& lmp_list);
106+
107+
private:
108+
tensorflow::Session* session;
109+
std::string name_scope, name_prefix;
110+
int num_intra_nthreads, num_inter_nthreads;
111+
tensorflow::GraphDef* graph_def;
112+
bool inited;
113+
double rcut;
114+
int dtype;
115+
double cell_size;
116+
int ntypes;
117+
std::string model_type;
118+
std::vector<int> sel_type;
119+
template <class VT>
120+
VT get_scalar(const std::string& name) const;
121+
template <class VT>
122+
void get_vector(std::vector<VT>& vec, const std::string& name) const;
123+
template <typename MODELTYPE, typename VALUETYPE>
124+
void run_model(std::vector<VALUETYPE>& dforce,
125+
std::vector<VALUETYPE>& dvirial,
126+
tensorflow::Session* session,
127+
const std::vector<std::pair<std::string, tensorflow::Tensor>>&
128+
input_tensors,
129+
const AtomMap& atommap,
130+
const int nghost);
131+
};
132+
} // namespace deepmd

0 commit comments

Comments
 (0)