Skip to content

Commit c824ff6

Browse files
authored
add ABC for descriptors (#1081)
* add ABC for descriptors I'm going to add abstract base classes for different objects, where a list of methods and attributes is defined to normalize classes and their external calls by other classes. It's also useful for developing and extending new classes. The first one I did is the descriptor. * TYPE_CHECKING doesn't work in python 3.6 * fix warnings
1 parent da5f688 commit c824ff6

File tree

8 files changed

+296
-7
lines changed

8 files changed

+296
-7
lines changed

deepmd/descriptor/descriptor.py

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Dict, List, Tuple
3+
4+
import numpy as np
5+
from deepmd.env import tf
6+
7+
8+
class Descriptor(ABC):
    r"""The abstract class for descriptors. All specific descriptors should
    be based on this class.

    The descriptor :math:`\mathcal{D}` describes the environment of an atom,
    which should be a function of coordinates and types of its neighbour atoms.

    Notes
    -----
    Only methods and attributes defined in this class are generally public,
    that can be called by other classes.

    Methods decorated with :func:`abc.abstractmethod` must be overridden by
    every concrete descriptor. The remaining methods are optional: their
    default implementation raises :class:`NotImplementedError`.
    """

    @abstractmethod
    def get_rcut(self) -> float:
        """
        Returns the cut-off radius.

        Returns
        -------
        float
            the cut-off radius

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    @abstractmethod
    def get_ntypes(self) -> int:
        """
        Returns the number of atom types.

        Returns
        -------
        int
            the number of atom types

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    @abstractmethod
    def get_dim_out(self) -> int:
        """
        Returns the output dimension of this descriptor.

        Returns
        -------
        int
            the output dimension of this descriptor

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    def get_dim_rot_mat_1(self) -> int:
        """
        Returns the first dimension of the rotation matrix. The rotation is of
        shape dim_1 x 3

        Returns
        -------
        int
            the first dimension of the rotation matrix

        Raises
        ------
        NotImplementedError
            If the concrete descriptor does not override this method.
        """
        # TODO: I think this method should be implemented as it's called by dipole and
        # polar fitting network. However, currently not all descriptors have this
        # method.
        raise NotImplementedError

    def get_nlist(self) -> Tuple[tf.Tensor, tf.Tensor, List[int], List[int]]:
        """
        Returns neighbor information.

        Returns
        -------
        nlist : tf.Tensor
            Neighbor list
        rij : tf.Tensor
            The relative distance between the neighbor and the center atom.
        sel_a : list[int]
            The number of neighbors with full information
        sel_r : list[int]
            The number of neighbors with only radial information

        Raises
        ------
        NotImplementedError
            If the concrete descriptor does not override this method.
        """
        # TODO: I think this method should be implemented as it's called by energy
        # model. However, se_ar and hybrid don't have this method.
        raise NotImplementedError

    @abstractmethod
    def compute_input_stats(
        self,
        data_coord: List[np.ndarray],
        data_box: List[np.ndarray],
        data_atype: List[np.ndarray],
        natoms_vec: List[np.ndarray],
        mesh: List[np.ndarray],
        input_dict: Dict[str, List[np.ndarray]]
    ) -> None:
        """
        Compute the statistics (avg and std) of the training data. The input
        will be normalized by the statistics.

        Parameters
        ----------
        data_coord : list[np.ndarray]
            The coordinates. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        data_box : list[np.ndarray]
            The box. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        data_atype : list[np.ndarray]
            The atom types. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        natoms_vec : list[np.ndarray]
            The vector for the number of atoms of the system and different
            types of atoms. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        mesh : list[np.ndarray]
            The mesh for neighbor searching. Can be generated by
            :meth:`deepmd.model.model_stat.make_stat_input`
        input_dict : dict[str, list[np.ndarray]]
            Dictionary for additional input

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    @abstractmethod
    def build(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box_: tf.Tensor,
        mesh: tf.Tensor,
        input_dict: Dict[str, Any],
        reuse: bool = None,
        suffix: str = '',
    ) -> tf.Tensor:
        """
        Build the computational graph for the descriptor.

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinate of atoms
        atype_ : tf.Tensor
            The type of atoms
        natoms : tf.Tensor
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        box_ : tf.Tensor
            The box of frames
        mesh : tf.Tensor
            For historical reasons, only the length of the Tensor matters.
            if size of mesh == 6, pbc is assumed.
            if size of mesh == 0, no-pbc is assumed.
        input_dict : dict[str, Any]
            Dictionary for additional inputs
        reuse : bool, optional
            Whether the weights in the networks should be reused when getting
            the variables.
        suffix : str, optional
            Name suffix to identify this descriptor

        Returns
        -------
        descriptor: tf.Tensor
            The output descriptor

        Notes
        -----
        This method must be implemented, as it's called by other classes.
        """

    def enable_compression(
        self,
        min_nbor_dist: float,
        # NOTE(review): 'frozon' looks like a typo for 'frozen'; it is kept
        # unchanged here because it is the public default value.
        model_file: str = 'frozon_model.pb',
        table_extrapolate: float = 5.,
        table_stride_1: float = 0.01,
        table_stride_2: float = 0.1,
        check_frequency: int = -1
    ) -> None:
        """
        Receive the statistics (distance, max_nbor_size and env_mat_range) of
        the training data.

        Parameters
        ----------
        min_nbor_dist : float
            The nearest distance between atoms
        model_file : str, default: 'frozon_model.pb'
            The original frozen model, which will be compressed by the program
        table_extrapolate : float, default: 5.
            The scale of model extrapolation
        table_stride_1 : float, default: 0.01
            The uniform stride of the first table
        table_stride_2 : float, default: 0.1
            The uniform stride of the second table
        check_frequency : int, default: -1
            The overflow check frequency

        Raises
        ------
        NotImplementedError
            If the concrete descriptor does not override this method, i.e. it
            does not support compression.

        Notes
        -----
        This method is called by others when the descriptor supported compression.
        """
        # Instances have no ``__name__`` attribute (only classes do), so the
        # class name must be looked up on the type; otherwise an AttributeError
        # would mask the intended NotImplementedError.
        raise NotImplementedError(
            "Descriptor %s doesn't support compression!" % type(self).__name__)

    @abstractmethod
    def prod_force_virial(
        self,
        atom_ener: tf.Tensor,
        natoms: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Compute force and virial.

        Parameters
        ----------
        atom_ener : tf.Tensor
            The atomic energy
        natoms : tf.Tensor
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms

        Returns
        -------
        force : tf.Tensor
            The force on atoms
        virial : tf.Tensor
            The total virial
        atom_virial : tf.Tensor
            The atomic virial
        """

    def get_feed_dict(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box: tf.Tensor,
        mesh: tf.Tensor
    ) -> Dict[str, tf.Tensor]:
        """
        Generate the feed_dict for current descriptor

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinate of atoms
        atype_ : tf.Tensor
            The type of atoms
        natoms : tf.Tensor
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        box : tf.Tensor
            The box. Can be generated by deepmd.model.make_stat_input
        mesh : tf.Tensor
            For historical reasons, only the length of the Tensor matters.
            if size of mesh == 6, pbc is assumed.
            if size of mesh == 0, no-pbc is assumed.

        Returns
        -------
        feed_dict : dict[str, tf.Tensor]
            The output feed_dict of current descriptor

        Raises
        ------
        NotImplementedError
            If the concrete descriptor does not override this method.
        """
        # TODO: currently only SeA has this method, but I think the method can be
        # moved here as it doesn't contain anything related to a specific descriptor
        raise NotImplementedError

deepmd/descriptor/hybrid.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# from deepmd.descriptor import DescrptSeAEbd
1313
# from deepmd.descriptor import DescrptSeAEf
1414
# from deepmd.descriptor import DescrptSeR
15+
from .descriptor import Descriptor
1516
from .se_a import DescrptSeA
1617
from .se_r import DescrptSeR
1718
from .se_ar import DescrptSeAR
@@ -20,7 +21,7 @@
2021
from .se_a_ef import DescrptSeAEf
2122
from .loc_frame import DescrptLocFrame
2223

23-
class DescrptHybrid ():
24+
class DescrptHybrid (Descriptor):
2425
"""Concate a list of descriptors to form a new descriptor.
2526
2627
Parameters

deepmd/descriptor/loc_frame.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from deepmd.env import op_module
88
from deepmd.env import default_tf_session_config
99
from deepmd.utils.sess import run_sess
10+
from .descriptor import Descriptor
1011

11-
class DescrptLocFrame () :
12+
class DescrptLocFrame (Descriptor) :
1213
"""Defines a local frame at each atom, and the compute the descriptor as local
1314
coordinates under this frame.
1415

deepmd/descriptor/se_a.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
from deepmd.utils.type_embed import embed_atom_type
1515
from deepmd.utils.sess import run_sess
1616
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
17+
from .descriptor import Descriptor
1718

18-
class DescrptSeA ():
19+
class DescrptSeA (Descriptor):
1920
r"""DeepPot-SE constructed from all information (both angular and radial) of
2021
atomic configurations. The embedding takes the distance between atoms as input.
2122

deepmd/descriptor/se_a_ef.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from deepmd.env import op_module
1111
from deepmd.env import default_tf_session_config
1212
from .se_a import DescrptSeA
13+
from .descriptor import Descriptor
1314

14-
class DescrptSeAEf ():
15+
class DescrptSeAEf (Descriptor):
1516
"""
1617
1718
Parameters

deepmd/descriptor/se_ar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from .se_a import DescrptSeA
66
from .se_r import DescrptSeR
77
from deepmd.env import op_module
8+
from .descriptor import Descriptor
89

9-
class DescrptSeAR ():
10+
class DescrptSeAR (Descriptor):
1011
def __init__ (self, jdata):
1112
args = ClassArg()\
1213
.add('a', dict, must = True) \

deepmd/descriptor/se_r.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from deepmd.env import default_tf_session_config
1111
from deepmd.utils.network import embedding_net, embedding_net_rand_seed_shift
1212
from deepmd.utils.sess import run_sess
13+
from .descriptor import Descriptor
1314

14-
class DescrptSeR ():
15+
class DescrptSeR (Descriptor):
1516
"""DeepPot-SE constructed from radial information of atomic configurations.
1617
1718
The embedding takes the distance between atoms as input.

deepmd/descriptor/se_t.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from deepmd.env import default_tf_session_config
1111
from deepmd.utils.network import embedding_net, embedding_net_rand_seed_shift
1212
from deepmd.utils.sess import run_sess
13+
from .descriptor import Descriptor
1314

14-
class DescrptSeT ():
15+
class DescrptSeT (Descriptor):
1516
"""DeepPot-SE constructed from all information (both angular and radial) of atomic
1617
configurations.
1718

0 commit comments

Comments
 (0)