@@ -1294,6 +1294,45 @@ def pick_atom_idx(self, idx, nopbc=None):
1294
1294
new_sys .append (ss .pick_atom_idx (idx , nopbc = nopbc ))
1295
1295
return new_sys
1296
1296
1297
+ def correction (self , hl_sys : "MultiSystems" ):
1298
+ """Get energy and force correction between self (assumed low-level) and a high-level MultiSystems.
1299
+ The self's coordinates will be kept, but energy and forces will be replaced by
1300
+ the correction between these two systems.
1301
+
1302
+ Notes
1303
+ -----
1304
+ This method will not check whether coordinates and elements of two systems
1305
+ are the same. The user should make sure by itself.
1306
+
1307
+ Parameters
1308
+ ----------
1309
+ hl_sys : MultiSystems
1310
+ high-level MultiSystems
1311
+
1312
+ Returns
1313
+ -------
1314
+ corrected_sys : MultiSystems
1315
+ Corrected MultiSystems
1316
+
1317
+ Examples
1318
+ --------
1319
+ Get correction between a low-level system and a high-level system:
1320
+
1321
+ >>> low_level = dpdata.MultiSystems().from_deepmd_hdf5("low_level.hdf5")
1322
+ >>> high_level = dpdata.MultiSystems().from_deepmd_hdf5("high_level.hdf5")
1323
+ >>> corr = low_level.correction(high_lebel)
1324
+ >>> corr.to_deepmd_hdf5("corr.hdf5")
1325
+ """
1326
+ if not isinstance (hl_sys , MultiSystems ):
1327
+ raise RuntimeError ("high_sys should be MultiSystems" )
1328
+ corrected_sys = MultiSystems (type_map = self .atom_names )
1329
+ for nn in self .systems .keys ():
1330
+ ll_ss = self [nn ]
1331
+ hl_ss = hl_sys [nn ]
1332
+ corrected_sys .append (ll_ss .correction (hl_ss ))
1333
+ return corrected_sys
1334
+
1335
+
1297
1336
def get_cls_name (cls : object ) -> str :
1298
1337
"""Returns the fully qualified name of a class, such as `np.ndarray`.
1299
1338
0 commit comments