Skip to content

Commit 2096b80

Browse files
authored
move utility to deepmd_utils (without modifaction) (#3140)
Move framework-independent codes to the `deepmd_utils` module without modification, as a step of #3118. --------- Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent ae90498 commit 2096b80

31 files changed

+5227
-4985
lines changed

deepmd/common.py

Lines changed: 35 additions & 236 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,65 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
"""Collection of functions and classes used throughout the whole package."""
33

4-
import json
54
import warnings
65
from functools import (
76
wraps,
87
)
9-
from pathlib import (
10-
Path,
11-
)
128
from typing import (
139
TYPE_CHECKING,
1410
Any,
1511
Callable,
16-
Dict,
17-
List,
18-
Optional,
19-
TypeVar,
2012
Union,
2113
)
2214

23-
import numpy as np
2415
import tensorflow
25-
import yaml
2616
from tensorflow.python.framework import (
2717
tensor_util,
2818
)
2919

3020
from deepmd.env import (
31-
GLOBAL_NP_FLOAT_PRECISION,
3221
GLOBAL_TF_FLOAT_PRECISION,
3322
op_module,
3423
tf,
3524
)
36-
from deepmd.utils.path import (
37-
DPPath,
25+
from deepmd_utils.common import (
26+
add_data_requirement,
27+
data_requirement,
28+
expand_sys_str,
29+
get_np_precision,
30+
j_loader,
31+
j_must_have,
32+
make_default_mesh,
33+
select_idx_map,
3834
)
3935

4036
if TYPE_CHECKING:
41-
_DICT_VAL = TypeVar("_DICT_VAL")
42-
_OBJ = TypeVar("_OBJ")
43-
try:
44-
from typing import Literal # python >3.6
45-
except ImportError:
46-
from typing_extensions import Literal # type: ignore
47-
_ACTIVATION = Literal[
48-
"relu", "relu6", "softplus", "sigmoid", "tanh", "gelu", "gelu_tf"
49-
]
50-
_PRECISION = Literal["default", "float16", "float32", "float64"]
37+
from deepmd_utils.common import (
38+
_ACTIVATION,
39+
_PRECISION,
40+
)
41+
42+
__all__ = [
43+
# from deepmd_utils.common
44+
"data_requirement",
45+
"add_data_requirement",
46+
"select_idx_map",
47+
"make_default_mesh",
48+
"j_must_have",
49+
"j_loader",
50+
"expand_sys_str",
51+
"get_np_precision",
52+
# from self
53+
"PRECISION_DICT",
54+
"gelu",
55+
"gelu_tf",
56+
"ACTIVATION_FN_DICT",
57+
"get_activation_func",
58+
"get_precision",
59+
"safe_cast_tensor",
60+
"cast_precision",
61+
"clear_session",
62+
]
5163

5264
# define constants
5365
PRECISION_DICT = {
@@ -115,10 +127,6 @@ def gelu_wrapper(x):
115127
return (lambda x: gelu_wrapper(x))(x)
116128

117129

118-
# TODO this is not a good way to do things. This is some global variable to which
119-
# TODO anyone can write and there is no good way to keep track of the changes
120-
data_requirement = {}
121-
122130
ACTIVATION_FN_DICT = {
123131
"relu": tf.nn.relu,
124132
"relu6": tf.nn.relu6,
@@ -132,164 +140,6 @@ def gelu_wrapper(x):
132140
}
133141

134142

135-
def add_data_requirement(
136-
key: str,
137-
ndof: int,
138-
atomic: bool = False,
139-
must: bool = False,
140-
high_prec: bool = False,
141-
type_sel: Optional[bool] = None,
142-
repeat: int = 1,
143-
default: float = 0.0,
144-
dtype: Optional[np.dtype] = None,
145-
):
146-
"""Specify data requirements for training.
147-
148-
Parameters
149-
----------
150-
key : str
151-
type of data stored in corresponding `*.npy` file e.g. `forces` or `energy`
152-
ndof : int
153-
number of the degrees of freedom, this is tied to `atomic` parameter e.g. forces
154-
have `atomic=True` and `ndof=3`
155-
atomic : bool, optional
156-
specifies whwther the `ndof` keyworrd applies to per atom quantity or not,
157-
by default False
158-
must : bool, optional
159-
specifi if the `*.npy` data file must exist, by default False
160-
high_prec : bool, optional
161-
if true load data to `np.float64` else `np.float32`, by default False
162-
type_sel : bool, optional
163-
select only certain type of atoms, by default None
164-
repeat : int, optional
165-
if specify repaeat data `repeat` times, by default 1
166-
default : float, optional, default=0.
167-
default value of data
168-
dtype : np.dtype, optional
169-
the dtype of data, overwrites `high_prec` if provided
170-
"""
171-
data_requirement[key] = {
172-
"ndof": ndof,
173-
"atomic": atomic,
174-
"must": must,
175-
"high_prec": high_prec,
176-
"type_sel": type_sel,
177-
"repeat": repeat,
178-
"default": default,
179-
"dtype": dtype,
180-
}
181-
182-
183-
def select_idx_map(atom_types: np.ndarray, select_types: np.ndarray) -> np.ndarray:
184-
"""Build map of indices for element supplied element types from all atoms list.
185-
186-
Parameters
187-
----------
188-
atom_types : np.ndarray
189-
array specifing type for each atoms as integer
190-
select_types : np.ndarray
191-
types of atoms you want to find indices for
192-
193-
Returns
194-
-------
195-
np.ndarray
196-
indices of types of atoms defined by `select_types` in `atom_types` array
197-
198-
Warnings
199-
--------
200-
`select_types` array will be sorted before finding indices in `atom_types`
201-
"""
202-
sort_select_types = np.sort(select_types)
203-
idx_map = []
204-
for ii in sort_select_types:
205-
idx_map.append(np.where(atom_types == ii)[0])
206-
return np.concatenate(idx_map)
207-
208-
209-
def make_default_mesh(pbc: bool, mixed_type: bool) -> np.ndarray:
210-
"""Make mesh.
211-
212-
Only the size of mesh matters, not the values:
213-
* 6 for PBC, no mixed types
214-
* 0 for no PBC, no mixed types
215-
* 7 for PBC, mixed types
216-
* 1 for no PBC, mixed types
217-
218-
Parameters
219-
----------
220-
pbc : bool
221-
if True, the mesh will be made for periodic boundary conditions
222-
mixed_type : bool
223-
if True, the mesh will be made for mixed types
224-
225-
Returns
226-
-------
227-
np.ndarray
228-
mesh
229-
"""
230-
mesh_size = int(pbc) * 6 + int(mixed_type)
231-
default_mesh = np.zeros(mesh_size, dtype=np.int32)
232-
return default_mesh
233-
234-
235-
# TODO maybe rename this to j_deprecated and only warn about deprecated keys,
236-
# TODO if the deprecated_key argument is left empty function puppose is only custom
237-
# TODO error since dict[key] already raises KeyError when the key is missing
238-
def j_must_have(
239-
jdata: Dict[str, "_DICT_VAL"], key: str, deprecated_key: List[str] = []
240-
) -> "_DICT_VAL":
241-
"""Assert that supplied dictionary conaines specified key.
242-
243-
Returns
244-
-------
245-
_DICT_VAL
246-
value that was store unde supplied key
247-
248-
Raises
249-
------
250-
RuntimeError
251-
if the key is not present
252-
"""
253-
if key not in jdata.keys():
254-
for ii in deprecated_key:
255-
if ii in jdata.keys():
256-
warnings.warn(f"the key {ii} is deprecated, please use {key} instead")
257-
return jdata[ii]
258-
else:
259-
raise RuntimeError(f"json database must provide key {key}")
260-
else:
261-
return jdata[key]
262-
263-
264-
def j_loader(filename: Union[str, Path]) -> Dict[str, Any]:
265-
"""Load yaml or json settings file.
266-
267-
Parameters
268-
----------
269-
filename : Union[str, Path]
270-
path to file
271-
272-
Returns
273-
-------
274-
Dict[str, Any]
275-
loaded dictionary
276-
277-
Raises
278-
------
279-
TypeError
280-
if the supplied file is of unsupported type
281-
"""
282-
filepath = Path(filename)
283-
if filepath.suffix.endswith("json"):
284-
with filepath.open() as fp:
285-
return json.load(fp)
286-
elif filepath.suffix.endswith(("yml", "yaml")):
287-
with filepath.open() as fp:
288-
return yaml.safe_load(fp)
289-
else:
290-
raise TypeError("config file must be json, or yaml/yml")
291-
292-
293143
def get_activation_func(
294144
activation_fn: Union["_ACTIVATION", None],
295145
) -> Union[Callable[[tf.Tensor], tf.Tensor], None]:
@@ -340,57 +190,6 @@ def get_precision(precision: "_PRECISION") -> Any:
340190
return PRECISION_DICT[precision]
341191

342192

343-
# TODO port completely to pathlib when all callers are ported
344-
def expand_sys_str(root_dir: Union[str, Path]) -> List[str]:
345-
"""Recursively iterate over directories taking those that contain `type.raw` file.
346-
347-
Parameters
348-
----------
349-
root_dir : Union[str, Path]
350-
starting directory
351-
352-
Returns
353-
-------
354-
List[str]
355-
list of string pointing to system directories
356-
"""
357-
root_dir = DPPath(root_dir)
358-
matches = [str(d) for d in root_dir.rglob("*") if (d / "type.raw").is_file()]
359-
if (root_dir / "type.raw").is_file():
360-
matches.append(str(root_dir))
361-
return matches
362-
363-
364-
def get_np_precision(precision: "_PRECISION") -> np.dtype:
365-
"""Get numpy precision constant from string.
366-
367-
Parameters
368-
----------
369-
precision : _PRECISION
370-
string name of numpy constant or default
371-
372-
Returns
373-
-------
374-
np.dtype
375-
numpy presicion constant
376-
377-
Raises
378-
------
379-
RuntimeError
380-
if string is invalid
381-
"""
382-
if precision == "default":
383-
return GLOBAL_NP_FLOAT_PRECISION
384-
elif precision == "float16":
385-
return np.float16
386-
elif precision == "float32":
387-
return np.float32
388-
elif precision == "float64":
389-
return np.float64
390-
else:
391-
raise RuntimeError(f"{precision} is not a valid precision")
392-
393-
394193
def safe_cast_tensor(
395194
input: tf.Tensor, from_precision: tf.DType, to_precision: tf.DType
396195
) -> tf.Tensor:

deepmd/env.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
)
2929

3030
import deepmd.lib
31+
from deepmd_utils.env import (
32+
GLOBAL_ENER_FLOAT_PRECISION,
33+
GLOBAL_NP_FLOAT_PRECISION,
34+
global_float_prec,
35+
)
3136

3237
if TYPE_CHECKING:
3338
from types import (
@@ -475,24 +480,7 @@ def _get_package_constants(
475480
op_grads_module = get_module("op_grads")
476481

477482
# FLOAT_PREC
478-
dp_float_prec = os.environ.get("DP_INTERFACE_PREC", "high").lower()
479-
if dp_float_prec in ("high", ""):
480-
# default is high
481-
GLOBAL_TF_FLOAT_PRECISION = tf.float64
482-
GLOBAL_NP_FLOAT_PRECISION = np.float64
483-
GLOBAL_ENER_FLOAT_PRECISION = np.float64
484-
global_float_prec = "double"
485-
elif dp_float_prec == "low":
486-
GLOBAL_TF_FLOAT_PRECISION = tf.float32
487-
GLOBAL_NP_FLOAT_PRECISION = np.float32
488-
GLOBAL_ENER_FLOAT_PRECISION = np.float64
489-
global_float_prec = "float"
490-
else:
491-
raise RuntimeError(
492-
"Unsupported float precision option: %s. Supported: high,"
493-
"low. Please set precision with environmental variable "
494-
"DP_INTERFACE_PREC." % dp_float_prec
495-
)
483+
GLOBAL_TF_FLOAT_PRECISION = tf.dtypes.as_dtype(GLOBAL_NP_FLOAT_PRECISION)
496484

497485

498486
def global_cvt_2_tf_float(xx: tf.Tensor) -> tf.Tensor:

deepmd/loggers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
"""Module taking care of logging duties."""
2+
"""Alias of deepmd_utils.loggers for backward compatibility."""
33

4-
from .loggers import (
4+
from deepmd_utils.loggers.loggers import (
55
set_log_handles,
66
)
77

0 commit comments

Comments
 (0)