11# SPDX-License-Identifier: LGPL-3.0-or-later
22"""Collection of functions and classes used throughout the whole package."""
33
4- import json
54import warnings
65from functools import (
76 wraps ,
87)
9- from pathlib import (
10- Path ,
11- )
128from typing import (
139 TYPE_CHECKING ,
1410 Any ,
1511 Callable ,
16- Dict ,
17- List ,
18- Optional ,
19- TypeVar ,
2012 Union ,
2113)
2214
23- import numpy as np
2415import tensorflow
25- import yaml
2616from tensorflow .python .framework import (
2717 tensor_util ,
2818)
2919
3020from deepmd .env import (
31- GLOBAL_NP_FLOAT_PRECISION ,
3221 GLOBAL_TF_FLOAT_PRECISION ,
3322 op_module ,
3423 tf ,
3524)
36- from deepmd .utils .path import (
37- DPPath ,
25+ from deepmd_utils .common import (
26+ add_data_requirement ,
27+ data_requirement ,
28+ expand_sys_str ,
29+ get_np_precision ,
30+ j_loader ,
31+ j_must_have ,
32+ make_default_mesh ,
33+ select_idx_map ,
3834)
3935
4036if TYPE_CHECKING :
41- _DICT_VAL = TypeVar ("_DICT_VAL" )
42- _OBJ = TypeVar ("_OBJ" )
43- try :
44- from typing import Literal # python >3.6
45- except ImportError :
46- from typing_extensions import Literal # type: ignore
47- _ACTIVATION = Literal [
48- "relu" , "relu6" , "softplus" , "sigmoid" , "tanh" , "gelu" , "gelu_tf"
49- ]
50- _PRECISION = Literal ["default" , "float16" , "float32" , "float64" ]
37+ from deepmd_utils .common import (
38+ _ACTIVATION ,
39+ _PRECISION ,
40+ )
41+
42+ __all__ = [
43+ # from deepmd_utils.common
44+ "data_requirement" ,
45+ "add_data_requirement" ,
46+ "select_idx_map" ,
47+ "make_default_mesh" ,
48+ "j_must_have" ,
49+ "j_loader" ,
50+ "expand_sys_str" ,
51+ "get_np_precision" ,
52+ # from self
53+ "PRECISION_DICT" ,
54+ "gelu" ,
55+ "gelu_tf" ,
56+ "ACTIVATION_FN_DICT" ,
57+ "get_activation_func" ,
58+ "get_precision" ,
59+ "safe_cast_tensor" ,
60+ "cast_precision" ,
61+ "clear_session" ,
62+ ]
5163
5264# define constants
5365PRECISION_DICT = {
@@ -115,10 +127,6 @@ def gelu_wrapper(x):
115127 return (lambda x : gelu_wrapper (x ))(x )
116128
117129
118- # TODO this is not a good way to do things. This is some global variable to which
119- # TODO anyone can write and there is no good way to keep track of the changes
120- data_requirement = {}
121-
122130ACTIVATION_FN_DICT = {
123131 "relu" : tf .nn .relu ,
124132 "relu6" : tf .nn .relu6 ,
@@ -132,164 +140,6 @@ def gelu_wrapper(x):
132140}
133141
134142
135- def add_data_requirement (
136- key : str ,
137- ndof : int ,
138- atomic : bool = False ,
139- must : bool = False ,
140- high_prec : bool = False ,
141- type_sel : Optional [bool ] = None ,
142- repeat : int = 1 ,
143- default : float = 0.0 ,
144- dtype : Optional [np .dtype ] = None ,
145- ):
146- """Specify data requirements for training.
147-
148- Parameters
149- ----------
150- key : str
151- type of data stored in corresponding `*.npy` file e.g. `forces` or `energy`
152- ndof : int
153- number of the degrees of freedom, this is tied to `atomic` parameter e.g. forces
154- have `atomic=True` and `ndof=3`
155- atomic : bool, optional
156- specifies whwther the `ndof` keyworrd applies to per atom quantity or not,
157- by default False
158- must : bool, optional
159- specifi if the `*.npy` data file must exist, by default False
160- high_prec : bool, optional
161- if true load data to `np.float64` else `np.float32`, by default False
162- type_sel : bool, optional
163- select only certain type of atoms, by default None
164- repeat : int, optional
165- if specify repaeat data `repeat` times, by default 1
166- default : float, optional, default=0.
167- default value of data
168- dtype : np.dtype, optional
169- the dtype of data, overwrites `high_prec` if provided
170- """
171- data_requirement [key ] = {
172- "ndof" : ndof ,
173- "atomic" : atomic ,
174- "must" : must ,
175- "high_prec" : high_prec ,
176- "type_sel" : type_sel ,
177- "repeat" : repeat ,
178- "default" : default ,
179- "dtype" : dtype ,
180- }
181-
182-
183- def select_idx_map (atom_types : np .ndarray , select_types : np .ndarray ) -> np .ndarray :
184- """Build map of indices for element supplied element types from all atoms list.
185-
186- Parameters
187- ----------
188- atom_types : np.ndarray
189- array specifing type for each atoms as integer
190- select_types : np.ndarray
191- types of atoms you want to find indices for
192-
193- Returns
194- -------
195- np.ndarray
196- indices of types of atoms defined by `select_types` in `atom_types` array
197-
198- Warnings
199- --------
200- `select_types` array will be sorted before finding indices in `atom_types`
201- """
202- sort_select_types = np .sort (select_types )
203- idx_map = []
204- for ii in sort_select_types :
205- idx_map .append (np .where (atom_types == ii )[0 ])
206- return np .concatenate (idx_map )
207-
208-
209- def make_default_mesh (pbc : bool , mixed_type : bool ) -> np .ndarray :
210- """Make mesh.
211-
212- Only the size of mesh matters, not the values:
213- * 6 for PBC, no mixed types
214- * 0 for no PBC, no mixed types
215- * 7 for PBC, mixed types
216- * 1 for no PBC, mixed types
217-
218- Parameters
219- ----------
220- pbc : bool
221- if True, the mesh will be made for periodic boundary conditions
222- mixed_type : bool
223- if True, the mesh will be made for mixed types
224-
225- Returns
226- -------
227- np.ndarray
228- mesh
229- """
230- mesh_size = int (pbc ) * 6 + int (mixed_type )
231- default_mesh = np .zeros (mesh_size , dtype = np .int32 )
232- return default_mesh
233-
234-
235- # TODO maybe rename this to j_deprecated and only warn about deprecated keys,
236- # TODO if the deprecated_key argument is left empty function puppose is only custom
237- # TODO error since dict[key] already raises KeyError when the key is missing
238- def j_must_have (
239- jdata : Dict [str , "_DICT_VAL" ], key : str , deprecated_key : List [str ] = []
240- ) -> "_DICT_VAL" :
241- """Assert that supplied dictionary conaines specified key.
242-
243- Returns
244- -------
245- _DICT_VAL
246- value that was store unde supplied key
247-
248- Raises
249- ------
250- RuntimeError
251- if the key is not present
252- """
253- if key not in jdata .keys ():
254- for ii in deprecated_key :
255- if ii in jdata .keys ():
256- warnings .warn (f"the key { ii } is deprecated, please use { key } instead" )
257- return jdata [ii ]
258- else :
259- raise RuntimeError (f"json database must provide key { key } " )
260- else :
261- return jdata [key ]
262-
263-
264- def j_loader (filename : Union [str , Path ]) -> Dict [str , Any ]:
265- """Load yaml or json settings file.
266-
267- Parameters
268- ----------
269- filename : Union[str, Path]
270- path to file
271-
272- Returns
273- -------
274- Dict[str, Any]
275- loaded dictionary
276-
277- Raises
278- ------
279- TypeError
280- if the supplied file is of unsupported type
281- """
282- filepath = Path (filename )
283- if filepath .suffix .endswith ("json" ):
284- with filepath .open () as fp :
285- return json .load (fp )
286- elif filepath .suffix .endswith (("yml" , "yaml" )):
287- with filepath .open () as fp :
288- return yaml .safe_load (fp )
289- else :
290- raise TypeError ("config file must be json, or yaml/yml" )
291-
292-
293143def get_activation_func (
294144 activation_fn : Union ["_ACTIVATION" , None ],
295145) -> Union [Callable [[tf .Tensor ], tf .Tensor ], None ]:
@@ -340,57 +190,6 @@ def get_precision(precision: "_PRECISION") -> Any:
340190 return PRECISION_DICT [precision ]
341191
342192
343- # TODO port completely to pathlib when all callers are ported
344- def expand_sys_str (root_dir : Union [str , Path ]) -> List [str ]:
345- """Recursively iterate over directories taking those that contain `type.raw` file.
346-
347- Parameters
348- ----------
349- root_dir : Union[str, Path]
350- starting directory
351-
352- Returns
353- -------
354- List[str]
355- list of string pointing to system directories
356- """
357- root_dir = DPPath (root_dir )
358- matches = [str (d ) for d in root_dir .rglob ("*" ) if (d / "type.raw" ).is_file ()]
359- if (root_dir / "type.raw" ).is_file ():
360- matches .append (str (root_dir ))
361- return matches
362-
363-
364- def get_np_precision (precision : "_PRECISION" ) -> np .dtype :
365- """Get numpy precision constant from string.
366-
367- Parameters
368- ----------
369- precision : _PRECISION
370- string name of numpy constant or default
371-
372- Returns
373- -------
374- np.dtype
375- numpy presicion constant
376-
377- Raises
378- ------
379- RuntimeError
380- if string is invalid
381- """
382- if precision == "default" :
383- return GLOBAL_NP_FLOAT_PRECISION
384- elif precision == "float16" :
385- return np .float16
386- elif precision == "float32" :
387- return np .float32
388- elif precision == "float64" :
389- return np .float64
390- else :
391- raise RuntimeError (f"{ precision } is not a valid precision" )
392-
393-
394193def safe_cast_tensor (
395194 input : tf .Tensor , from_precision : tf .DType , to_precision : tf .DType
396195) -> tf .Tensor :
0 commit comments