Skip to content

Commit ef0797e

Browse files
authored
add correction method to MultiSystems (#315)
1 parent 744bf3b commit ef0797e

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

dpdata/system.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,45 @@ def pick_atom_idx(self, idx, nopbc=None):
12941294
new_sys.append(ss.pick_atom_idx(idx, nopbc=nopbc))
12951295
return new_sys
12961296

1297+
def correction(self, hl_sys: "MultiSystems"):
1298+
"""Get energy and force correction between self (assumed low-level) and a high-level MultiSystems.
1299+
The self's coordinates will be kept, but energy and forces will be replaced by
1300+
the correction between these two systems.
1301+
1302+
Notes
1303+
-----
1304+
This method will not check whether coordinates and elements of two systems
1305+
are the same. The user should make sure by itself.
1306+
1307+
Parameters
1308+
----------
1309+
hl_sys : MultiSystems
1310+
high-level MultiSystems
1311+
1312+
Returns
1313+
-------
1314+
corrected_sys : MultiSystems
1315+
Corrected MultiSystems
1316+
1317+
Examples
1318+
--------
1319+
Get correction between a low-level system and a high-level system:
1320+
1321+
>>> low_level = dpdata.MultiSystems().from_deepmd_hdf5("low_level.hdf5")
1322+
>>> high_level = dpdata.MultiSystems().from_deepmd_hdf5("high_level.hdf5")
1323+
>>> corr = low_level.correction(high_lebel)
1324+
>>> corr.to_deepmd_hdf5("corr.hdf5")
1325+
"""
1326+
if not isinstance(hl_sys, MultiSystems):
1327+
raise RuntimeError("high_sys should be MultiSystems")
1328+
corrected_sys = MultiSystems(type_map=self.atom_names)
1329+
for nn in self.systems.keys():
1330+
ll_ss = self[nn]
1331+
hl_ss = hl_sys[nn]
1332+
corrected_sys.append(ll_ss.correction(hl_ss))
1333+
return corrected_sys
1334+
1335+
12971336
def get_cls_name(cls: object) -> str:
12981337
"""Returns the fully qualified name of a class, such as `np.ndarray`.
12991338

tests/test_corr.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,19 @@ def setUp(self):
2828
self.f_places = 6
2929
self.v_places = 6
3030

31+
32+
class TestCorr(unittest.TestCase, CompLabeledSys, IsPBC):
33+
"""Make a test to get a correction of two MultiSystems."""
34+
def setUp(self):
35+
s_ll = dpdata.MultiSystems(dpdata.LabeledSystem("amber/corr/dp_ll", fmt="deepmd/npy"))
36+
s_hl = dpdata.MultiSystems(dpdata.LabeledSystem("amber/corr/dp_hl", fmt="deepmd/npy"))
37+
self.system_1 = tuple(s_ll.correction(s_hl).systems.values())[0]
38+
self.system_2 = dpdata.LabeledSystem("amber/corr/dp_corr" ,fmt="deepmd/npy")
39+
self.places = 5
40+
self.e_places = 4
41+
self.f_places = 6
42+
self.v_places = 6
43+
44+
3145
if __name__ == '__main__':
3246
unittest.main()

0 commit comments

Comments
 (0)