Skip to content

Commit 0b0d9a9

Browse files
authored
add stat module to compute MAE and RMSE (#316)
1 parent ef0797e commit 0b0d9a9

File tree

3 files changed

+185
-0
lines changed

3 files changed

+185
-0
lines changed

dpdata/stat.py

Lines changed: 159 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,159 @@
1+
from abc import ABCMeta, abstractmethod, abstractproperty
from functools import lru_cache

import numpy as np

from dpdata.system import LabeledSystem, MultiSystems
7+
8+
9+
def mae(errors: np.ndarray) -> np.float64:
    """Return the mean absolute error (MAE) of a set of deviations.

    Parameters
    ----------
    errors : np.ndarray
        differences between two sets of values

    Returns
    -------
    np.float64
        mean of the absolute values of ``errors``
    """
    magnitudes = np.absolute(errors)
    return np.mean(magnitudes)
23+
24+
25+
def rmse(errors: np.ndarray) -> np.float64:
    """Return the root mean squared error (RMSE) of a set of deviations.

    Parameters
    ----------
    errors : np.ndarray
        differences between two sets of values

    Returns
    -------
    np.float64
        square root of the mean of the squared ``errors``
    """
    squared = np.square(errors)
    mean_squared = np.mean(squared)
    return np.sqrt(mean_squared)
39+
40+
41+
class ErrorsBase(metaclass=ABCMeta):
    """Abstract base for computing errors (deviations) between two systems.

    Subclasses set ``SYSTEM_TYPE`` to the class both systems must be
    instances of and implement ``e_errors`` and ``f_errors``. The MAE and
    RMSE properties are derived from those two arrays.

    Parameters
    ----------
    system_1 : SYSTEM_TYPE
        system 1
    system_2 : SYSTEM_TYPE
        system 2
    """

    SYSTEM_TYPE = object

    def __init__(self, system_1: SYSTEM_TYPE, system_2: SYSTEM_TYPE) -> None:
        # NOTE(review): assert statements are skipped under ``python -O``.
        # They are kept (rather than raising TypeError) so callers still see
        # AssertionError on a type mismatch.
        assert isinstance(system_1, self.SYSTEM_TYPE), "system_1 should be %s" % self.SYSTEM_TYPE.__name__
        assert isinstance(system_2, self.SYSTEM_TYPE), "system_2 should be %s" % self.SYSTEM_TYPE.__name__
        self.system_1 = system_1
        self.system_2 = system_2

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``property`` on ``abstractmethod`` is the supported spelling.
    @property
    @abstractmethod
    def e_errors(self) -> np.ndarray:
        """Energy errors."""

    @property
    @abstractmethod
    def f_errors(self) -> np.ndarray:
        """Force errors."""

    @property
    def e_mae(self) -> np.float64:
        """Energy MAE."""
        return mae(self.e_errors)

    @property
    def e_rmse(self) -> np.float64:
        """Energy RMSE."""
        return rmse(self.e_errors)

    @property
    def f_mae(self) -> np.float64:
        """Force MAE."""
        return mae(self.f_errors)

    @property
    def f_rmse(self) -> np.float64:
        """Force RMSE."""
        return rmse(self.f_errors)
86+
87+
88+
class Errors(ErrorsBase):
    """Compute errors (deviations) between two LabeledSystems.

    Errors are computed as ``system_1`` minus ``system_2``. Each error array
    is computed on first access and then stored on the instance, so repeated
    access returns the same array.

    Parameters
    ----------
    system_1 : LabeledSystem
        system 1
    system_2 : LabeledSystem
        system 2

    Examples
    --------
    Get errors between referenced system and predicted system:

    >>> e = dpdata.stat.Errors(system_1, system_2)
    >>> print("%.4f %.4f %.4f %.4f" % (e.e_mae, e.e_rmse, e.f_mae, e.f_rmse))
    """

    SYSTEM_TYPE = LabeledSystem

    # The results are cached in the instance ``__dict__`` instead of with
    # ``functools.lru_cache``: an lru_cache on a method keys on ``self`` and
    # keeps every instance (and its systems) alive in a module-level cache.

    @property
    def e_errors(self) -> np.ndarray:
        """Energy errors."""
        cache = self.__dict__
        if "_e_errors" not in cache:
            cache["_e_errors"] = self.system_1['energies'] - self.system_2['energies']
        return cache["_e_errors"]

    @property
    def f_errors(self) -> np.ndarray:
        """Force errors, flattened to one dimension."""
        cache = self.__dict__
        if "_f_errors" not in cache:
            cache["_f_errors"] = (self.system_1['forces'] - self.system_2['forces']).ravel()
        return cache["_f_errors"]
118+
119+
120+
class MultiErrors(ErrorsBase):
    """Compute errors (deviations) between two MultiSystems.

    For every key in ``system_1.systems`` the sub-systems of ``system_1`` and
    ``system_2`` stored under that key are compared with ``Errors``, and the
    flattened per-system errors are concatenated. Each result is computed on
    first access and then stored on the instance.

    Parameters
    ----------
    system_1 : MultiSystems
        system 1
    system_2 : MultiSystems
        system 2

    Examples
    --------
    Get errors between referenced system and predicted system:

    >>> e = dpdata.stat.MultiErrors(system_1, system_2)
    >>> print("%.4f %.4f %.4f %.4f" % (e.e_mae, e.e_rmse, e.f_mae, e.f_rmse))
    """

    SYSTEM_TYPE = MultiSystems

    def _collect_errors(self, attr: str) -> np.ndarray:
        """Concatenate the flattened ``attr`` errors of every sub-system pair."""
        return np.concatenate(
            [
                getattr(Errors(self.system_1[name], self.system_2[name]), attr).ravel()
                for name in self.system_1.systems
            ]
        )

    # The results are cached in the instance ``__dict__`` instead of with
    # ``functools.lru_cache``: an lru_cache on a method keys on ``self`` and
    # keeps every instance (and its systems) alive in a module-level cache.

    @property
    def e_errors(self) -> np.ndarray:
        """Energy errors."""
        cache = self.__dict__
        if "_e_errors" not in cache:
            cache["_e_errors"] = self._collect_errors("e_errors")
        return cache["_e_errors"]

    @property
    def f_errors(self) -> np.ndarray:
        """Force errors."""
        cache = self.__dict__
        if "_f_errors" not in cache:
            cache["_f_errors"] = self._collect_errors("f_errors")
        return cache["_f_errors"]

tests/context.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,3 +5,4 @@
55
import dpdata.md.msd
66
import dpdata.gaussian.gjf
77
import dpdata.system
8+
import dpdata.stat

tests/test_stat.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,25 @@
1+
from context import dpdata
2+
3+
import unittest
4+
5+
6+
class TestStat(unittest.TestCase):
    """Check the MAE/RMSE values reported by ``dpdata.stat`` for two sample systems."""

    @staticmethod
    def _first_system():
        """Load the labeled system parsed from the Gaussian log sample."""
        return dpdata.LabeledSystem("gaussian/methane.gaussianlog", fmt="gaussian/log")

    @staticmethod
    def _second_system():
        """Load the labeled system parsed from the SQM output sample."""
        return dpdata.LabeledSystem("amber/sqm_opt.out", fmt="sqm/out")

    def _assert_expected_errors(self, errors):
        """Compare all four statistics against the stored values (6 places)."""
        expected = (
            ("e_mae", 1014.7946598792427),
            ("e_rmse", 1014.7946598792427),
            ("f_mae", 0.004113640526088011),
            ("f_rmse", 0.005714011247538185),
        )
        for attr, value in expected:
            self.assertAlmostEqual(getattr(errors, attr), value, places=6)

    def test_errors(self):
        first = self._first_system()
        second = self._second_system()

        self._assert_expected_errors(dpdata.stat.Errors(first, second))

    def test_multi_errors(self):
        first = dpdata.MultiSystems(self._first_system())
        second = dpdata.MultiSystems(self._second_system())

        self._assert_expected_errors(dpdata.stat.MultiErrors(first, second))

0 commit comments

Comments
 (0)