Skip to content

Commit f8bb170

Browse files
authored
feat: calculate the real error in dp model-devi (#2757)
This PR adds a new flag, `--real_error`, to `dp model-devi`. When the flag is set, the real error defined in the DP-GEN paper is calculated instead of the model deviation. ![Screenshot from 2023-08-23 15-43-57](https://github.com/deepmodeling/deepmd-kit/assets/9496702/331dd7fc-f1d4-4b37-be9f-80c78d027602) One can use this feature for `dpgen simplify` if the labeled data exists. Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 1db2195 commit f8bb170

File tree

3 files changed

+119
-10
lines changed

3 files changed

+119
-10
lines changed

deepmd/infer/model_devi.py

Lines changed: 86 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,18 @@
2121
)
2222

2323

24-
def calc_model_devi_f(fs: np.ndarray) -> Tuple[np.ndarray]:
24+
def calc_model_devi_f(
25+
fs: np.ndarray, real_f: Optional[np.ndarray] = None
26+
) -> Tuple[np.ndarray]:
2527
"""Calculate model deviation of force.
2628
2729
Parameters
2830
----------
2931
fs : numpy.ndarray
3032
size of `n_models x n_frames x n_atoms x 3`
33+
real_f : numpy.ndarray or None
34+
real force, size of `n_frames x n_atoms x 3`. If given,
35+
the RMS real error is calculated instead.
3136
3237
Returns
3338
-------
@@ -38,14 +43,21 @@ def calc_model_devi_f(fs: np.ndarray) -> Tuple[np.ndarray]:
3843
avg_devi_f : numpy.ndarray
3944
average deviation of force in all atoms
4045
"""
41-
fs_devi = np.linalg.norm(np.std(fs, axis=0), axis=-1)
46+
if real_f is None:
47+
fs_devi = np.linalg.norm(np.std(fs, axis=0), axis=-1)
48+
else:
49+
fs_devi = np.linalg.norm(
50+
np.sqrt(np.mean(np.square(fs - real_f), axis=0)), axis=-1
51+
)
4252
max_devi_f = np.max(fs_devi, axis=-1)
4353
min_devi_f = np.min(fs_devi, axis=-1)
4454
avg_devi_f = np.mean(fs_devi, axis=-1)
4555
return max_devi_f, min_devi_f, avg_devi_f
4656

4757

48-
def calc_model_devi_e(es: np.ndarray) -> np.ndarray:
58+
def calc_model_devi_e(
59+
es: np.ndarray, real_e: Optional[np.ndarray] = None
60+
) -> np.ndarray:
4961
"""Calculate model deviation of total energy per atom.
5062
5163
Here we don't use the atomic energy, as the decomposition
@@ -56,24 +68,35 @@ def calc_model_devi_e(es: np.ndarray) -> np.ndarray:
5668
----------
5769
es : numpy.ndarray
5870
size of `n_models x n_frames x 1`
71+
real_e : numpy.ndarray or None
72+
real energy, size of `n_frames x 1`. If given,
73+
the RMS real error is calculated instead.
5974
6075
Returns
6176
-------
6277
max_devi_e : numpy.ndarray
6378
maximum deviation of energy
6479
"""
65-
es_devi = np.std(es, axis=0)
80+
if real_e is None:
81+
es_devi = np.std(es, axis=0)
82+
else:
83+
es_devi = np.sqrt(np.mean(np.square(es - real_e), axis=0))
6684
es_devi = np.squeeze(es_devi, axis=-1)
6785
return es_devi
6886

6987

70-
def calc_model_devi_v(vs: np.ndarray) -> Tuple[np.ndarray]:
88+
def calc_model_devi_v(
89+
vs: np.ndarray, real_v: Optional[np.ndarray] = None
90+
) -> Tuple[np.ndarray]:
7191
"""Calculate model deviation of virial.
7292
7393
Parameters
7494
----------
7595
vs : numpy.ndarray
7696
size of `n_models x n_frames x 9`
97+
real_v : numpy.ndarray or None
98+
real virial, size of `n_frames x 9`. If given,
99+
the RMS real error is calculated instead.
77100
78101
Returns
79102
-------
@@ -84,7 +107,10 @@ def calc_model_devi_v(vs: np.ndarray) -> Tuple[np.ndarray]:
84107
avg_devi_v : numpy.ndarray
85108
average deviation of virial in 9 elements
86109
"""
87-
vs_devi = np.std(vs, axis=0)
110+
if real_v is None:
111+
vs_devi = np.std(vs, axis=0)
112+
else:
113+
vs_devi = np.sqrt(np.mean(np.square(vs - real_v), axis=0))
88114
max_devi_v = np.max(vs_devi, axis=-1)
89115
min_devi_v = np.min(vs_devi, axis=-1)
90116
avg_devi_v = np.linalg.norm(vs_devi, axis=-1) / 3
@@ -148,6 +174,7 @@ def calc_model_devi(
148174
mixed_type=False,
149175
fparam: Optional[np.ndarray] = None,
150176
aparam: Optional[np.ndarray] = None,
177+
real_data: Optional[dict] = None,
151178
):
152179
"""Python interface to calculate model deviation.
153180
@@ -171,6 +198,8 @@ def calc_model_devi(
171198
frame specific parameters
172199
aparam : numpy.ndarray
173200
atomic specific parameters
201+
real_data : dict, optional
202+
real data (a dict with `energy`, `force` and `virial` entries) used to calculate the RMS real error
174203
175204
Returns
176205
-------
@@ -211,17 +240,29 @@ def calc_model_devi(
211240
virials = np.array(virials)
212241

213242
devi = [np.arange(coord.shape[0]) * frequency]
214-
devi += list(calc_model_devi_v(virials))
215-
devi += list(calc_model_devi_f(forces))
216-
devi.append(calc_model_devi_e(energies))
243+
if real_data is None:
244+
devi += list(calc_model_devi_v(virials))
245+
devi += list(calc_model_devi_f(forces))
246+
devi.append(calc_model_devi_e(energies))
247+
else:
248+
devi += list(calc_model_devi_v(virials, real_data["virial"]))
249+
devi += list(calc_model_devi_f(forces, real_data["force"]))
250+
devi.append(calc_model_devi_e(energies, real_data["energy"]))
217251
devi = np.vstack(devi).T
218252
if fname:
219253
write_model_devi_out(devi, fname)
220254
return devi
221255

222256

223257
def make_model_devi(
224-
*, models: list, system: str, set_prefix: str, output: str, frequency: int, **kwargs
258+
*,
259+
models: list,
260+
system: str,
261+
set_prefix: str,
262+
output: str,
263+
frequency: int,
264+
real_error: bool = False,
265+
**kwargs,
225266
):
226267
"""Make model deviation calculation.
227268
@@ -239,6 +280,8 @@ def make_model_devi(
239280
The number of steps that elapse between writing coordinates
240281
in a trajectory by a MD engine (such as Gromacs / Lammps).
241282
This parameter is used to determine the index in the output file.
283+
real_error : bool, default: False
284+
If True, calculate the RMS real error instead of model deviation.
242285
**kwargs
243286
Arbitrary keyword arguments.
244287
"""
@@ -279,6 +322,29 @@ def make_model_devi(
279322
must=True,
280323
high_prec=False,
281324
)
325+
if real_error:
326+
dp_data.add(
327+
"energy",
328+
1,
329+
atomic=False,
330+
must=False,
331+
high_prec=True,
332+
)
333+
dp_data.add(
334+
"force",
335+
3,
336+
atomic=True,
337+
must=False,
338+
high_prec=False,
339+
)
340+
dp_data.add(
341+
"virial",
342+
9,
343+
atomic=False,
344+
must=False,
345+
high_prec=False,
346+
)
347+
282348
mixed_type = dp_data.mixed_type
283349

284350
data_sets = [dp_data._load_set(set_name) for set_name in dp_data.dirs]
@@ -301,6 +367,15 @@ def make_model_devi(
301367
aparam = data["aparam"]
302368
else:
303369
aparam = None
370+
if real_error:
371+
natoms = atype.shape[-1]
372+
real_data = {
373+
"energy": data["energy"] / natoms,
374+
"force": data["force"].reshape([-1, natoms, 3]),
375+
"virial": data["virial"] / natoms,
376+
}
377+
else:
378+
real_data = None
304379
devi = calc_model_devi(
305380
coord,
306381
box,
@@ -309,6 +384,7 @@ def make_model_devi(
309384
mixed_type=mixed_type,
310385
fparam=fparam,
311386
aparam=aparam,
387+
real_data=real_data,
312388
)
313389
nframes_tot += coord.shape[0]
314390
devis.append(devi)

deepmd_cli/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,12 @@ def main_parser() -> argparse.ArgumentParser:
438438
type=int,
439439
help="The trajectory frequency of the system",
440440
)
441+
parser_model_devi.add_argument(
442+
"--real_error",
443+
action="store_true",
444+
default=False,
445+
help="Calculate the RMS real error of the model. The real data should be given in the systems.",
446+
)
441447

442448
# * convert models
443449
parser_transform = subparsers.add_parser(

source/tests/test_model_devi.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,33 @@ def test_make_model_devi(self):
8686
x = np.loadtxt(self.output)
8787
np.testing.assert_allclose(x, self.expect, 6)
8888

89+
def test_make_model_devi_real_erorr(self):
90+
make_model_devi(
91+
models=self.graph_dirs,
92+
system=self.data_dir,
93+
set_prefix="set",
94+
output=self.output,
95+
frequency=self.freq,
96+
real_error=True,
97+
)
98+
x = np.loadtxt(self.output)
99+
np.testing.assert_allclose(
100+
x,
101+
np.array(
102+
[
103+
0.000000e00,
104+
6.709021e-01,
105+
1.634359e-03,
106+
3.219720e-01,
107+
2.018684e00,
108+
1.829748e00,
109+
1.956474e00,
110+
1.550898e02,
111+
]
112+
),
113+
6,
114+
)
115+
89116
def tearDown(self):
90117
for pb in self.graph_dirs:
91118
os.remove(pb)

0 commit comments

Comments
 (0)