28
28
from monty .serialization import loadfn ,dumpfn
29
29
from dpdata .periodic_table import Element
30
30
from dpdata .xyz .quip_gap_xyz import QuipGapxyzSystems
31
+ from dpdata .amber .mask import pick_by_amber_mask , load_param_file
31
32
32
33
33
34
class Register :
@@ -927,6 +928,10 @@ def nopbc(self):
927
928
return True
928
929
return False
929
930
931
+ @nopbc .setter
932
+ def nopbc (self , value ):
933
+ self .data ['nopbc' ] = value
934
+
930
935
def shuffle (self ):
931
936
"""Shuffle frames randomly."""
932
937
idx = np .random .permutation (self .get_nframes ())
@@ -973,6 +978,93 @@ def predict(self, dp):
973
978
labeled_sys .append (this_sys )
974
979
return labeled_sys
975
980
981
+ def pick_atom_idx (self , idx , nopbc = None ):
982
+ """Pick atom index
983
+
984
+ Parameters
985
+ ----------
986
+ idx: int or list or slice
987
+ atom index
988
+ nopbc: Boolen (default: None)
989
+ If nopbc is True or False, set nopbc
990
+
991
+ Returns
992
+ -------
993
+ new_sys: System
994
+ new system
995
+ """
996
+ new_sys = self .copy ()
997
+ new_sys .data ['coords' ] = self .data ['coords' ][:, idx , :]
998
+ new_sys .data ['atom_types' ] = self .data ['atom_types' ][idx ]
999
+ # recalculate atom_numbs according to atom_types
1000
+ atom_numbs = np .bincount (new_sys .data ['atom_types' ], minlength = len (self .get_atom_names ()))
1001
+ new_sys .data ['atom_numbs' ] = list (atom_numbs )
1002
+ if nopbc is True or nopbc is False :
1003
+ new_sys .nopbc = nopbc
1004
+ return new_sys
1005
+
1006
+ def remove_atom_names (self , atom_names ):
1007
+ """Remove atom names and all such atoms.
1008
+ For example, you may not remove EP atoms in TIP4P/Ew water, which
1009
+ is not a real atom.
1010
+ """
1011
+ if isinstance (atom_names , str ):
1012
+ atom_names = [atom_names ]
1013
+ removed_atom_idx = []
1014
+ for an in atom_names :
1015
+ # get atom name idx
1016
+ idx = self .data ['atom_names' ].index (an )
1017
+ atom_idx = self .data ['atom_types' ] == idx
1018
+ removed_atom_idx .append (atom_idx )
1019
+ picked_atom_idx = ~ np .any (removed_atom_idx , axis = 0 )
1020
+ new_sys = self .pick_atom_idx (picked_atom_idx )
1021
+ # let's remove atom_names
1022
+ # firstly, rearrange atom_names and put these atom_names in the end
1023
+ new_atom_names = list ([xx for xx in new_sys .data ['atom_names' ] if xx not in atom_names ])
1024
+ new_sys .sort_atom_names (type_map = new_atom_names + atom_names )
1025
+ # remove atom_names and atom_numbs
1026
+ new_sys .data ['atom_names' ] = new_atom_names
1027
+ new_sys .data ['atom_numbs' ] = new_sys .data ['atom_numbs' ][:len (new_atom_names )]
1028
+ return new_sys
1029
+
1030
+ def pick_by_amber_mask (self , param , maskstr , pass_coords = False , nopbc = None ):
1031
+ """Pick atoms by amber mask
1032
+
1033
+ Parameters
1034
+ ----------
1035
+ param: str or parmed.Structure
1036
+ filename of Amber param file or parmed.Structure
1037
+ maskstr: str
1038
+ Amber masks
1039
+ pass_coords: Boolen (default: False)
1040
+ If pass_coords is true, the function will pass coordinates and
1041
+ return a MultiSystem. Otherwise, the result is
1042
+ coordinate-independent, and the function will return System or
1043
+ LabeledSystem.
1044
+ nopbc: Boolen (default: None)
1045
+ If nopbc is True or False, set nopbc
1046
+ """
1047
+ parm = load_param_file (param )
1048
+ if pass_coords :
1049
+ ms = MultiSystems ()
1050
+ for sub_s in self :
1051
+ # TODO: this can computed in pararrel
1052
+ idx = pick_by_amber_mask (parm , maskstr , sub_s ['coords' ][0 ])
1053
+ ms .append (sub_s .pick_atom_idx (idx , nopbc = nopbc ))
1054
+ return ms
1055
+ else :
1056
+ idx = pick_by_amber_mask (parm , maskstr )
1057
+ return self .pick_atom_idx (idx , nopbc = nopbc )
1058
+
1059
+ @register_from_funcs .register_funcs ('amber/md' )
1060
+ def from_amber_md (self , file_name = None , parm7_file = None , nc_file = None , use_element_symbols = None ):
1061
+ # assume the prefix is the same if the spefic name is not given
1062
+ if parm7_file is None :
1063
+ parm7_file = file_name + ".parm7"
1064
+ if nc_file is None :
1065
+ nc_file = file_name + ".nc"
1066
+ self .data = dpdata .amber .md .read_amber_traj (parm7_file = parm7_file , nc_file = nc_file , use_element_symbols = use_element_symbols , labeled = False )
1067
+
976
1068
def get_cell_perturb_matrix (cell_pert_fraction ):
977
1069
if cell_pert_fraction < 0 :
978
1070
raise RuntimeError ('cell_pert_fraction can not be negative' )
@@ -1305,7 +1397,7 @@ def from_gaussian_md(self, file_name):
1305
1397
self .from_gaussian_log (file_name , md = True )
1306
1398
1307
1399
@register_from_funcs .register_funcs ('amber/md' )
1308
- def from_amber_md (self , file_name = None , parm7_file = None , nc_file = None , mdfrc_file = None , mden_file = None ):
1400
+ def from_amber_md (self , file_name = None , parm7_file = None , nc_file = None , mdfrc_file = None , mden_file = None , mdout_file = None , use_element_symbols = None ):
1309
1401
# assume the prefix is the same if the spefic name is not given
1310
1402
if parm7_file is None :
1311
1403
parm7_file = file_name + ".parm7"
@@ -1315,7 +1407,9 @@ def from_amber_md(self, file_name=None, parm7_file=None, nc_file=None, mdfrc_fil
1315
1407
mdfrc_file = file_name + ".mdfrc"
1316
1408
if mden_file is None :
1317
1409
mden_file = file_name + ".mden"
1318
- self .data = dpdata .amber .md .read_amber_traj (parm7_file , nc_file , mdfrc_file , mden_file )
1410
+ if mdout_file is None :
1411
+ mdout_file = file_name + ".mdout"
1412
+ self .data = dpdata .amber .md .read_amber_traj (parm7_file , nc_file , mdfrc_file , mden_file , mdout_file , use_element_symbols )
1319
1413
1320
1414
@register_from_funcs .register_funcs ('cp2k/output' )
1321
1415
def from_cp2k_output (self , file_name ) :
@@ -1475,6 +1569,53 @@ def to_pymatgen_ComputedStructureEntry(self):
1475
1569
entries .append (entry )
1476
1570
return entries
1477
1571
1572
+ def correction (self , hl_sys ):
1573
+ """Get energy and force correction between self and a high-level LabeledSystem.
1574
+ The self's coordinates will be kept, but energy and forces will be replaced by
1575
+ the correction between these two systems.
1576
+
1577
+ Note: The function will not check whether coordinates and elements of two systems
1578
+ are the same. The user should make sure by itself.
1579
+
1580
+ Parameters
1581
+ ----------
1582
+ hl_sys: LabeledSystem
1583
+ high-level LabeledSystem
1584
+ Returns
1585
+ ----------
1586
+ corrected_sys: LabeledSystem
1587
+ Corrected LabeledSystem
1588
+ """
1589
+ if not isinstance (hl_sys , LabeledSystem ):
1590
+ raise RuntimeError ("high_sys should be LabeledSystem" )
1591
+ corrected_sys = self .copy ()
1592
+ corrected_sys .data ['energies' ] = hl_sys .data ['energies' ] - self .data ['energies' ]
1593
+ corrected_sys .data ['forces' ] = hl_sys .data ['forces' ] - self .data ['forces' ]
1594
+ if 'virials' in self .data and 'virials' in hl_sys .data :
1595
+ corrected_sys .data ['virials' ] = hl_sys .data ['virials' ] - self .data ['virials' ]
1596
+ return corrected_sys
1597
+
1598
+ def pick_atom_idx (self , idx , nopbc = None ):
1599
+ """Pick atom index
1600
+
1601
+ Parameters
1602
+ ----------
1603
+ idx: int or list or slice
1604
+ atom index
1605
+ nopbc: Boolen (default: None)
1606
+ If nopbc is True or False, set nopbc
1607
+
1608
+ Returns
1609
+ -------
1610
+ new_sys: LabeledSystem
1611
+ new system
1612
+ """
1613
+ new_sys = System .pick_atom_idx (self , idx , nopbc = nopbc )
1614
+ # forces
1615
+ new_sys .data ['forces' ] = self .data ['forces' ][:, idx , :]
1616
+ return new_sys
1617
+
1618
+
1478
1619
class MultiSystems :
1479
1620
'''A set containing several systems.'''
1480
1621
@@ -1650,6 +1791,26 @@ def predict(self, dp):
1650
1791
for ss in self :
1651
1792
new_multisystems .append (ss .predict (dp ))
1652
1793
return new_multisystems
1794
+
1795
+ def pick_atom_idx (self , idx , nopbc = None ):
1796
+ """Pick atom index
1797
+
1798
+ Parameters
1799
+ ----------
1800
+ idx: int or list or slice
1801
+ atom index
1802
+ nopbc: Boolen (default: None)
1803
+ If nopbc is True or False, set nopbc
1804
+
1805
+ Returns
1806
+ -------
1807
+ new_sys: MultiSystems
1808
+ new system
1809
+ """
1810
+ new_sys = MultiSystems ()
1811
+ for ss in self :
1812
+ new_sys .append (ss .pick_atom_idx (idx , nopbc = nopbc ))
1813
+ return new_sys
1653
1814
1654
1815
1655
1816
def check_System (data ):
0 commit comments