Skip to content

Commit fa6a213

Browse files
njzjzamcadmus
andauthored
enhance and add several Amber-related functions (#135)
* add some useful functions: pick_atom_idx, correction * pick_atom_idx for MultiSystems * fix bugs * move all constants to one file * load amber energy from mdout file if mden file does not exist * add remove_atom_names function; bugfix * add nopbc.setter * use_element_symbols for amber * pick_by_amber_mask * fix bugs * make constants more accurate * fix tests * fix bug * System from amber * Revert "fix tests" This reverts commit 9f65fac. * Revert "make constants more accurate" This reverts commit 43947eb. * Revert "move all constants to one file" This reverts commit 5a85a78. * add tests and fix bugs * add extras_require * split unittests Co-authored-by: Han Wang <[email protected]>
1 parent e979419 commit fa6a213

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+33133
-33
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
with:
1919
python-version: ${{ matrix.python-version }}
2020
- name: Install dependencies
21-
run: pip install . coverage codecov
21+
run: pip install .[amber] coverage codecov
2222
- name: Test
2323
run: cd tests && coverage run --source=../dpdata -m unittest && cd .. && coverage combine tests/.coverage && coverage report
2424
- run: codecov

dpdata/amber/mask.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Amber mask"""
2+
try:
3+
import parmed
4+
except ImportError:
5+
pass
6+
7+
def pick_by_amber_mask(param, maskstr, coords=None):
8+
"""Pick atoms by amber masks
9+
10+
Parameters
11+
----------
12+
param: str or parmed.Structure
13+
filename of Amber param file or parmed.Structure
14+
maskstr: str
15+
Amber masks
16+
coords: np.ndarray (optional)
17+
frame coordinates, shape: N*3
18+
"""
19+
parm = load_param_file(param)
20+
if coords is not None:
21+
parm.initialize_topology(xyz=coords)
22+
sele = []
23+
if len(maskstr) > 0:
24+
newmaskstr = maskstr.replace("@0", "!@*")
25+
sele = [parm.atoms[i].idx for i in parmed.amber.mask.AmberMask(
26+
parm, newmaskstr).Selected()]
27+
return sele
28+
29+
def load_param_file(param_file):
30+
if isinstance(param_file, str):
31+
return parmed.load_file(param_file)
32+
elif isinstance(param_file, parmed.Structure):
33+
return param_file
34+
else:
35+
raise RuntimeError("Unsupported structure")

dpdata/amber/md.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,47 @@
11
import re
2+
import os
23
from scipy.io import netcdf
34
import numpy as np
5+
from dpdata.amber.mask import pick_by_amber_mask
46

57
kcalmol2eV= 0.04336410390059322
8+
symbols = ['X', 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og']
69

710
energy_convert = kcalmol2eV
811
force_convert = energy_convert
912

1013

11-
def read_amber_traj(parm7_file, nc_file, mdfrc_file, mden_file):
14+
def read_amber_traj(parm7_file, nc_file, mdfrc_file=None, mden_file = None, mdout_file = None,
15+
use_element_symbols=None, labeled=True,
16+
):
1217
"""The amber trajectory includes:
1318
* nc, NetCDF format, stores coordinates
1419
* mdfrc, NetCDF format, stores forces
15-
* mden, text format, stores energies
20+
* mden (optional), text format, stores energies
21+
* mdout (optional), text format, may store energies if there is no mden_file
1622
* parm7, text format, stores types
23+
24+
Parameters
25+
----------
26+
parm7_file, nc_file, mdfrc_file, mden_file, mdout_file:
27+
filenames
28+
use_element_symbols: None or list or str
29+
If use_element_symbols is a list of atom indexes, these atoms will use element symbols
30+
instead of amber types. For example, a ligand will use C, H, O, N, and so on
31+
instead of h1, hc, o, os, and so on.
32+
IF use_element_symbols is str, it will be considered as Amber mask.
1733
"""
1834

19-
flag=False
35+
flag_atom_type = False
36+
flag_atom_numb = False
2037
amber_types = []
38+
atomic_number = []
2139
with open(parm7_file) as f:
2240
for line in f:
2341
if line.startswith("%FLAG"):
24-
flag = line.startswith("%FLAG AMBER_ATOM_TYPE")
25-
elif flag:
42+
flag_atom_type = line.startswith("%FLAG AMBER_ATOM_TYPE")
43+
flag_atom_numb = (use_element_symbols is not None) and line.startswith("%FLAG ATOMIC_NUMBER")
44+
elif flag_atom_type or flag_atom_numb:
2645
if line.startswith("%FORMAT"):
2746
fmt = re.findall(r'\d+', line)
2847
fmt0 = int(fmt[0])
@@ -33,7 +52,16 @@ def read_amber_traj(parm7_file, nc_file, mdfrc_file, mden_file):
3352
end_index = (ii + 1) * fmt1
3453
if end_index >= len(line):
3554
continue
36-
amber_types.append(line[start_index:end_index].strip())
55+
content = line[start_index:end_index].strip()
56+
if flag_atom_type:
57+
amber_types.append(content)
58+
elif flag_atom_numb:
59+
atomic_number.append(int(content))
60+
if use_element_symbols is not None:
61+
if isinstance(use_element_symbols, str):
62+
use_element_symbols = pick_by_amber_mask(parm7_file, use_element_symbols)
63+
for ii in use_element_symbols:
64+
amber_types[ii] = symbols[atomic_number[ii]]
3765

3866
with netcdf.netcdf_file(nc_file, 'r') as f:
3967
coords = np.array(f.variables["coordinates"][:])
@@ -49,26 +77,37 @@ def read_amber_traj(parm7_file, nc_file, mdfrc_file, mden_file):
4977
else:
5078
raise RuntimeError("Unsupported cells")
5179

52-
with netcdf.netcdf_file(mdfrc_file, 'r') as f:
53-
forces = np.array(f.variables["forces"][:])
80+
if labeled:
81+
with netcdf.netcdf_file(mdfrc_file, 'r') as f:
82+
forces = np.array(f.variables["forces"][:])
5483

55-
# energy
56-
energies = []
57-
with open(mden_file) as f:
58-
for line in f:
59-
if line.startswith("L6"):
60-
s = line.split()
61-
if s[2] != "E_pot":
62-
energies.append(float(s[2]))
84+
# load energy from mden_file or mdout_file
85+
energies = []
86+
if mden_file is not None and os.path.isfile(mden_file):
87+
with open(mden_file) as f:
88+
for line in f:
89+
if line.startswith("L6"):
90+
s = line.split()
91+
if s[2] != "E_pot":
92+
energies.append(float(s[2]))
93+
elif mdout_file is not None and os.path.isfile(mdout_file):
94+
with open(mdout_file) as f:
95+
for line in f:
96+
if "EPtot" in line:
97+
s = line.split()
98+
energies.append(float(s[-1]))
99+
else:
100+
raise RuntimeError("Please provide one of mden_file and mdout_file")
63101

64102
atom_names, atom_types, atom_numbs = np.unique(amber_types, return_inverse=True, return_counts=True)
65103

66104
data = {}
67105
data['atom_names'] = list(atom_names)
68106
data['atom_numbs'] = list(atom_numbs)
69107
data['atom_types'] = atom_types
70-
data['forces'] = forces * force_convert
71-
data['energies'] = np.array(energies) * energy_convert
108+
if labeled:
109+
data['forces'] = forces * force_convert
110+
data['energies'] = np.array(energies) * energy_convert
72111
data['coords'] = coords
73112
data['cells'] = cells
74113
data['orig'] = np.array([0, 0, 0])

dpdata/system.py

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from monty.serialization import loadfn,dumpfn
2929
from dpdata.periodic_table import Element
3030
from dpdata.xyz.quip_gap_xyz import QuipGapxyzSystems
31+
from dpdata.amber.mask import pick_by_amber_mask, load_param_file
3132

3233

3334
class Register:
@@ -927,6 +928,10 @@ def nopbc(self):
927928
return True
928929
return False
929930

931+
@nopbc.setter
932+
def nopbc(self, value):
933+
self.data['nopbc'] = value
934+
930935
def shuffle(self):
931936
"""Shuffle frames randomly."""
932937
idx = np.random.permutation(self.get_nframes())
@@ -973,6 +978,93 @@ def predict(self, dp):
973978
labeled_sys.append(this_sys)
974979
return labeled_sys
975980

981+
def pick_atom_idx(self, idx, nopbc=None):
982+
"""Pick atom index
983+
984+
Parameters
985+
----------
986+
idx: int or list or slice
987+
atom index
988+
nopbc: Boolen (default: None)
989+
If nopbc is True or False, set nopbc
990+
991+
Returns
992+
-------
993+
new_sys: System
994+
new system
995+
"""
996+
new_sys = self.copy()
997+
new_sys.data['coords'] = self.data['coords'][:, idx, :]
998+
new_sys.data['atom_types'] = self.data['atom_types'][idx]
999+
# recalculate atom_numbs according to atom_types
1000+
atom_numbs = np.bincount(new_sys.data['atom_types'], minlength=len(self.get_atom_names()))
1001+
new_sys.data['atom_numbs'] = list(atom_numbs)
1002+
if nopbc is True or nopbc is False:
1003+
new_sys.nopbc = nopbc
1004+
return new_sys
1005+
1006+
def remove_atom_names(self, atom_names):
1007+
"""Remove atom names and all such atoms.
1008+
For example, you may not remove EP atoms in TIP4P/Ew water, which
1009+
is not a real atom.
1010+
"""
1011+
if isinstance(atom_names, str):
1012+
atom_names = [atom_names]
1013+
removed_atom_idx = []
1014+
for an in atom_names:
1015+
# get atom name idx
1016+
idx = self.data['atom_names'].index(an)
1017+
atom_idx = self.data['atom_types'] == idx
1018+
removed_atom_idx.append(atom_idx)
1019+
picked_atom_idx = ~np.any(removed_atom_idx, axis=0)
1020+
new_sys = self.pick_atom_idx(picked_atom_idx)
1021+
# let's remove atom_names
1022+
# firstly, rearrange atom_names and put these atom_names in the end
1023+
new_atom_names = list([xx for xx in new_sys.data['atom_names'] if xx not in atom_names])
1024+
new_sys.sort_atom_names(type_map=new_atom_names + atom_names)
1025+
# remove atom_names and atom_numbs
1026+
new_sys.data['atom_names'] = new_atom_names
1027+
new_sys.data['atom_numbs'] = new_sys.data['atom_numbs'][:len(new_atom_names)]
1028+
return new_sys
1029+
1030+
def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None):
1031+
"""Pick atoms by amber mask
1032+
1033+
Parameters
1034+
----------
1035+
param: str or parmed.Structure
1036+
filename of Amber param file or parmed.Structure
1037+
maskstr: str
1038+
Amber masks
1039+
pass_coords: Boolen (default: False)
1040+
If pass_coords is true, the function will pass coordinates and
1041+
return a MultiSystem. Otherwise, the result is
1042+
coordinate-independent, and the function will return System or
1043+
LabeledSystem.
1044+
nopbc: Boolen (default: None)
1045+
If nopbc is True or False, set nopbc
1046+
"""
1047+
parm = load_param_file(param)
1048+
if pass_coords:
1049+
ms = MultiSystems()
1050+
for sub_s in self:
1051+
# TODO: this can computed in pararrel
1052+
idx = pick_by_amber_mask(parm, maskstr, sub_s['coords'][0])
1053+
ms.append(sub_s.pick_atom_idx(idx, nopbc=nopbc))
1054+
return ms
1055+
else:
1056+
idx = pick_by_amber_mask(parm, maskstr)
1057+
return self.pick_atom_idx(idx, nopbc=nopbc)
1058+
1059+
@register_from_funcs.register_funcs('amber/md')
1060+
def from_amber_md(self, file_name=None, parm7_file=None, nc_file=None, use_element_symbols=None):
1061+
# assume the prefix is the same if the spefic name is not given
1062+
if parm7_file is None:
1063+
parm7_file = file_name + ".parm7"
1064+
if nc_file is None:
1065+
nc_file = file_name + ".nc"
1066+
self.data = dpdata.amber.md.read_amber_traj(parm7_file=parm7_file, nc_file=nc_file, use_element_symbols=use_element_symbols, labeled=False)
1067+
9761068
def get_cell_perturb_matrix(cell_pert_fraction):
9771069
if cell_pert_fraction<0:
9781070
raise RuntimeError('cell_pert_fraction can not be negative')
@@ -1305,7 +1397,7 @@ def from_gaussian_md(self, file_name):
13051397
self.from_gaussian_log(file_name, md=True)
13061398

13071399
@register_from_funcs.register_funcs('amber/md')
1308-
def from_amber_md(self, file_name=None, parm7_file=None, nc_file=None, mdfrc_file=None, mden_file=None):
1400+
def from_amber_md(self, file_name=None, parm7_file=None, nc_file=None, mdfrc_file=None, mden_file=None, mdout_file=None, use_element_symbols=None):
13091401
# assume the prefix is the same if the spefic name is not given
13101402
if parm7_file is None:
13111403
parm7_file = file_name + ".parm7"
@@ -1315,7 +1407,9 @@ def from_amber_md(self, file_name=None, parm7_file=None, nc_file=None, mdfrc_fil
13151407
mdfrc_file = file_name + ".mdfrc"
13161408
if mden_file is None:
13171409
mden_file = file_name + ".mden"
1318-
self.data = dpdata.amber.md.read_amber_traj(parm7_file, nc_file, mdfrc_file, mden_file)
1410+
if mdout_file is None:
1411+
mdout_file = file_name + ".mdout"
1412+
self.data = dpdata.amber.md.read_amber_traj(parm7_file, nc_file, mdfrc_file, mden_file, mdout_file, use_element_symbols)
13191413

13201414
@register_from_funcs.register_funcs('cp2k/output')
13211415
def from_cp2k_output(self, file_name) :
@@ -1475,6 +1569,53 @@ def to_pymatgen_ComputedStructureEntry(self):
14751569
entries.append(entry)
14761570
return entries
14771571

1572+
def correction(self, hl_sys):
1573+
"""Get energy and force correction between self and a high-level LabeledSystem.
1574+
The self's coordinates will be kept, but energy and forces will be replaced by
1575+
the correction between these two systems.
1576+
1577+
Note: The function will not check whether coordinates and elements of two systems
1578+
are the same. The user should make sure by itself.
1579+
1580+
Parameters
1581+
----------
1582+
hl_sys: LabeledSystem
1583+
high-level LabeledSystem
1584+
Returns
1585+
----------
1586+
corrected_sys: LabeledSystem
1587+
Corrected LabeledSystem
1588+
"""
1589+
if not isinstance(hl_sys, LabeledSystem):
1590+
raise RuntimeError("high_sys should be LabeledSystem")
1591+
corrected_sys = self.copy()
1592+
corrected_sys.data['energies'] = hl_sys.data['energies'] - self.data['energies']
1593+
corrected_sys.data['forces'] = hl_sys.data['forces'] - self.data['forces']
1594+
if 'virials' in self.data and 'virials' in hl_sys.data:
1595+
corrected_sys.data['virials'] = hl_sys.data['virials'] - self.data['virials']
1596+
return corrected_sys
1597+
1598+
def pick_atom_idx(self, idx, nopbc=None):
1599+
"""Pick atom index
1600+
1601+
Parameters
1602+
----------
1603+
idx: int or list or slice
1604+
atom index
1605+
nopbc: Boolen (default: None)
1606+
If nopbc is True or False, set nopbc
1607+
1608+
Returns
1609+
-------
1610+
new_sys: LabeledSystem
1611+
new system
1612+
"""
1613+
new_sys = System.pick_atom_idx(self, idx, nopbc=nopbc)
1614+
# forces
1615+
new_sys.data['forces'] = self.data['forces'][:, idx, :]
1616+
return new_sys
1617+
1618+
14781619
class MultiSystems:
14791620
'''A set containing several systems.'''
14801621

@@ -1650,6 +1791,26 @@ def predict(self, dp):
16501791
for ss in self:
16511792
new_multisystems.append(ss.predict(dp))
16521793
return new_multisystems
1794+
1795+
def pick_atom_idx(self, idx, nopbc=None):
1796+
"""Pick atom index
1797+
1798+
Parameters
1799+
----------
1800+
idx: int or list or slice
1801+
atom index
1802+
nopbc: Boolen (default: None)
1803+
If nopbc is True or False, set nopbc
1804+
1805+
Returns
1806+
-------
1807+
new_sys: MultiSystems
1808+
new system
1809+
"""
1810+
new_sys = MultiSystems()
1811+
for ss in self:
1812+
new_sys.append(ss.pick_atom_idx(idx, nopbc=nopbc))
1813+
return new_sys
16531814

16541815

16551816
def check_System(data):

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,9 @@
4646
],
4747
keywords='lammps vasp deepmd-kit',
4848
install_requires=install_requires,
49+
extras_require={
50+
'ase': ['ase'],
51+
'amber': ['parmed'],
52+
}
4953
)
5054

tests/amber/corr/dataset/C6H11HW192O6OW96P1/nopbc

Whitespace-only changes.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)