Skip to content

Commit 850575a

Browse files
authored
cc: refactor DeepTensor for multiple-backend framework (#3151)
See #3119 --------- Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 9be1ad2 commit 850575a

File tree

4 files changed

+1512
-712
lines changed

4 files changed

+1512
-712
lines changed

source/api_cc/include/DeepTensor.h

Lines changed: 179 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,184 @@
11
// SPDX-License-Identifier: LGPL-3.0-or-later
22
#pragma once
33

4+
#include <memory>
5+
46
#include "common.h"
57
#include "neighbor_list.h"
68

79
namespace deepmd {
10+
/**
11+
* @brief Deep Tensor.
12+
**/
13+
class DeepTensorBase {
14+
public:
15+
/**
16+
* @brief Deep Tensor constructor without initialization.
17+
**/
18+
DeepTensorBase(){};
19+
virtual ~DeepTensorBase(){};
20+
/**
21+
* @brief Deep Tensor constructor with initialization..
22+
* @param[in] model The name of the frozen model file.
23+
* @param[in] gpu_rank The GPU rank. Default is 0.
24+
* @param[in] name_scope Name scopes of operations.
25+
**/
26+
DeepTensorBase(const std::string& model,
27+
const int& gpu_rank = 0,
28+
const std::string& name_scope = "");
29+
/**
30+
* @brief Initialize the Deep Tensor.
31+
* @param[in] model The name of the frozen model file.
32+
* @param[in] gpu_rank The GPU rank. Default is 0.
33+
* @param[in] name_scope Name scopes of operations.
34+
**/
35+
virtual void init(const std::string& model,
36+
const int& gpu_rank = 0,
37+
const std::string& name_scope = "") = 0;
38+
39+
/**
40+
* @brief Evaluate the value by using this model.
41+
* @param[out] value The value to evalute, usually would be the atomic tensor.
42+
* @param[in] coord The coordinates of atoms. The array should be of size
43+
*natoms x 3.
44+
* @param[in] atype The atom types. The list should contain natoms ints.
45+
* @param[in] box The cell of the region. The array should be of size 9.
46+
* @{
47+
**/
48+
virtual void computew(std::vector<double>& value,
49+
const std::vector<double>& coord,
50+
const std::vector<int>& atype,
51+
const std::vector<double>& box) = 0;
52+
virtual void computew(std::vector<float>& value,
53+
const std::vector<float>& coord,
54+
const std::vector<int>& atype,
55+
const std::vector<float>& box) = 0;
56+
/** @} */
57+
/**
58+
* @brief Evaluate the value by using this model.
59+
* @param[out] value The value to evalute, usually would be the atomic tensor.
60+
* @param[in] coord The coordinates of atoms. The array should be of size
61+
*natoms x 3.
62+
* @param[in] atype The atom types. The list should contain natoms ints.
63+
* @param[in] box The cell of the region. The array should be of size 9.
64+
* @param[in] nghost The number of ghost atoms.
65+
* @param[in] inlist The input neighbour list.
66+
* @{
67+
**/
68+
virtual void computew(std::vector<double>& value,
69+
const std::vector<double>& coord,
70+
const std::vector<int>& atype,
71+
const std::vector<double>& box,
72+
const int nghost,
73+
const InputNlist& inlist) = 0;
74+
virtual void computew(std::vector<float>& value,
75+
const std::vector<float>& coord,
76+
const std::vector<int>& atype,
77+
const std::vector<float>& box,
78+
const int nghost,
79+
const InputNlist& inlist) = 0;
80+
/** @} */
81+
/**
82+
* @brief Evaluate the global tensor and component-wise force and virial.
83+
* @param[out] global_tensor The global tensor to evalute.
84+
* @param[out] force The component-wise force of the global tensor, size odim
85+
*x natoms x 3.
86+
* @param[out] virial The component-wise virial of the global tensor, size
87+
*odim x 9.
88+
* @param[out] atom_tensor The atomic tensor value of the model, size natoms x
89+
*odim.
90+
* @param[out] atom_virial The component-wise atomic virial of the global
91+
*tensor, size odim x natoms x 9.
92+
* @param[in] coord The coordinates of atoms. The array should be of size
93+
*natoms x 3.
94+
* @param[in] atype The atom types. The list should contain natoms ints.
95+
* @param[in] box The cell of the region. The array should be of size 9.
96+
* @{
97+
**/
98+
virtual void computew(std::vector<double>& global_tensor,
99+
std::vector<double>& force,
100+
std::vector<double>& virial,
101+
std::vector<double>& atom_tensor,
102+
std::vector<double>& atom_virial,
103+
const std::vector<double>& coord,
104+
const std::vector<int>& atype,
105+
const std::vector<double>& box) = 0;
106+
virtual void computew(std::vector<float>& global_tensor,
107+
std::vector<float>& force,
108+
std::vector<float>& virial,
109+
std::vector<float>& atom_tensor,
110+
std::vector<float>& atom_virial,
111+
const std::vector<float>& coord,
112+
const std::vector<int>& atype,
113+
const std::vector<float>& box) = 0;
114+
/** @} */
115+
/**
116+
* @brief Evaluate the global tensor and component-wise force and virial.
117+
* @param[out] global_tensor The global tensor to evalute.
118+
* @param[out] force The component-wise force of the global tensor, size odim
119+
*x natoms x 3.
120+
* @param[out] virial The component-wise virial of the global tensor, size
121+
*odim x 9.
122+
* @param[out] atom_tensor The atomic tensor value of the model, size natoms x
123+
*odim.
124+
* @param[out] atom_virial The component-wise atomic virial of the global
125+
*tensor, size odim x natoms x 9.
126+
* @param[in] coord The coordinates of atoms. The array should be of size
127+
*natoms x 3.
128+
* @param[in] atype The atom types. The list should contain natoms ints.
129+
* @param[in] box The cell of the region. The array should be of size 9.
130+
* @param[in] nghost The number of ghost atoms.
131+
* @param[in] inlist The input neighbour list.
132+
* @{
133+
**/
134+
virtual void computew(std::vector<double>& global_tensor,
135+
std::vector<double>& force,
136+
std::vector<double>& virial,
137+
std::vector<double>& atom_tensor,
138+
std::vector<double>& atom_virial,
139+
const std::vector<double>& coord,
140+
const std::vector<int>& atype,
141+
const std::vector<double>& box,
142+
const int nghost,
143+
const InputNlist& inlist) = 0;
144+
virtual void computew(std::vector<float>& global_tensor,
145+
std::vector<float>& force,
146+
std::vector<float>& virial,
147+
std::vector<float>& atom_tensor,
148+
std::vector<float>& atom_virial,
149+
const std::vector<float>& coord,
150+
const std::vector<int>& atype,
151+
const std::vector<float>& box,
152+
const int nghost,
153+
const InputNlist& inlist) = 0;
154+
/** @} */
155+
/**
156+
* @brief Get the cutoff radius.
157+
* @return The cutoff radius.
158+
**/
159+
virtual double cutoff() const = 0;
160+
/**
161+
* @brief Get the number of types.
162+
* @return The number of types.
163+
**/
164+
virtual int numb_types() const = 0;
165+
/**
166+
* @brief Get the output dimension.
167+
* @return The output dimension.
168+
**/
169+
virtual int output_dim() const = 0;
170+
/**
171+
* @brief Get the list of sel types.
172+
* @return The list of sel types.
173+
*/
174+
virtual const std::vector<int>& sel_types() const = 0;
175+
/**
176+
* @brief Get the type map (element name of the atom types) of this model.
177+
* @param[out] type_map The type map of this model.
178+
**/
179+
virtual void get_type_map(std::string& type_map) = 0;
180+
};
181+
8182
/**
9183
* @brief Deep Tensor.
10184
**/
@@ -169,109 +343,30 @@ class DeepTensor {
169343
* @brief Get the cutoff radius.
170344
* @return The cutoff radius.
171345
**/
172-
double cutoff() const {
173-
assert(inited);
174-
return rcut;
175-
};
346+
double cutoff() const;
176347
/**
177348
* @brief Get the number of types.
178349
* @return The number of types.
179350
**/
180-
int numb_types() const {
181-
assert(inited);
182-
return ntypes;
183-
};
351+
int numb_types() const;
184352
/**
185353
* @brief Get the output dimension.
186354
* @return The output dimension.
187355
**/
188-
int output_dim() const {
189-
assert(inited);
190-
return odim;
191-
};
356+
int output_dim() const;
192357
/**
193358
* @brief Get the list of sel types.
194359
* @return The list of sel types.
195360
*/
196-
const std::vector<int>& sel_types() const {
197-
assert(inited);
198-
return sel_type;
199-
};
361+
const std::vector<int>& sel_types() const;
200362
/**
201363
* @brief Get the type map (element name of the atom types) of this model.
202364
* @param[out] type_map The type map of this model.
203365
**/
204366
void get_type_map(std::string& type_map);
205367

206368
private:
207-
tensorflow::Session* session;
208-
std::string name_scope;
209-
int num_intra_nthreads, num_inter_nthreads;
210-
tensorflow::GraphDef* graph_def;
211369
bool inited;
212-
double rcut;
213-
int dtype;
214-
double cell_size;
215-
int ntypes;
216-
std::string model_type;
217-
std::string model_version;
218-
int odim;
219-
std::vector<int> sel_type;
220-
template <class VT>
221-
VT get_scalar(const std::string& name) const;
222-
template <class VT>
223-
void get_vector(std::vector<VT>& vec, const std::string& name) const;
224-
template <typename MODELTYPE, typename VALUETYPE>
225-
void run_model(std::vector<VALUETYPE>& d_tensor_,
226-
tensorflow::Session* session,
227-
const std::vector<std::pair<std::string, tensorflow::Tensor>>&
228-
input_tensors,
229-
const AtomMap& atommap,
230-
const std::vector<int>& sel_fwd,
231-
const int nghost = 0);
232-
template <typename MODELTYPE, typename VALUETYPE>
233-
void run_model(std::vector<VALUETYPE>& dglobal_tensor_,
234-
std::vector<VALUETYPE>& dforce_,
235-
std::vector<VALUETYPE>& dvirial_,
236-
std::vector<VALUETYPE>& datom_tensor_,
237-
std::vector<VALUETYPE>& datom_virial_,
238-
tensorflow::Session* session,
239-
const std::vector<std::pair<std::string, tensorflow::Tensor>>&
240-
input_tensors,
241-
const AtomMap& atommap,
242-
const std::vector<int>& sel_fwd,
243-
const int nghost = 0);
244-
template <typename VALUETYPE>
245-
void compute_inner(std::vector<VALUETYPE>& value,
246-
const std::vector<VALUETYPE>& coord,
247-
const std::vector<int>& atype,
248-
const std::vector<VALUETYPE>& box);
249-
template <typename VALUETYPE>
250-
void compute_inner(std::vector<VALUETYPE>& value,
251-
const std::vector<VALUETYPE>& coord,
252-
const std::vector<int>& atype,
253-
const std::vector<VALUETYPE>& box,
254-
const int nghost,
255-
const InputNlist& inlist);
256-
template <typename VALUETYPE>
257-
void compute_inner(std::vector<VALUETYPE>& global_tensor,
258-
std::vector<VALUETYPE>& force,
259-
std::vector<VALUETYPE>& virial,
260-
std::vector<VALUETYPE>& atom_tensor,
261-
std::vector<VALUETYPE>& atom_virial,
262-
const std::vector<VALUETYPE>& coord,
263-
const std::vector<int>& atype,
264-
const std::vector<VALUETYPE>& box);
265-
template <typename VALUETYPE>
266-
void compute_inner(std::vector<VALUETYPE>& global_tensor,
267-
std::vector<VALUETYPE>& force,
268-
std::vector<VALUETYPE>& virial,
269-
std::vector<VALUETYPE>& atom_tensor,
270-
std::vector<VALUETYPE>& atom_virial,
271-
const std::vector<VALUETYPE>& coord,
272-
const std::vector<int>& atype,
273-
const std::vector<VALUETYPE>& box,
274-
const int nghost,
275-
const InputNlist& inlist);
370+
std::shared_ptr<deepmd::DeepTensorBase> dt;
276371
};
277372
} // namespace deepmd

0 commit comments

Comments
 (0)