Skip to content

Commit d622c95

Browse files
Add support for 'deepmd/mixed' format with dpdata.MultiSystems (#422)
Add support for 'deepmd/mixed' format with dpdata.MultiSystems 1. Support dump from dpdata.MultiSystems to [mixed type format](https://github.com/deepmodeling/deepmd-kit/blob/master/doc/model/train-se-atten.md#data-format): dpdata.MultiSystems.to_deepmd_mixed('dir_name_mixed') 2. Support load from mixed type format to dpdata.MultiSystems: dpdata.MultiSystems.load_systems_from_file('dir_name_mixed', fmt='deepmd/mixed') --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 18508af commit d622c95

File tree

6 files changed

+581
-13
lines changed

6 files changed

+581
-13
lines changed

README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ The `System` or `LabeledSystem` can be constructed from the following file forma
6262
| deepmd | npy | True | False | System | 'deepmd/npy' |
6363
| deepmd | raw | True | True | LabeledSystem | 'deepmd/raw' |
6464
| deepmd | npy | True | True | LabeledSystem | 'deepmd/npy' |
65+
| deepmd | npy | True | True | MultiSystems | 'deepmd/npy/mixed' |
66+
| deepmd | npy | True | False | MultiSystems | 'deepmd/npy/mixed' |
6567
| gaussian| log | False | True | LabeledSystem | 'gaussian/log'|
6668
| gaussian| log | True | True | LabeledSystem | 'gaussian/md' |
6769
| siesta | output | False | True | LabeledSystem | 'siesta/output'|
@@ -278,6 +280,30 @@ print(syst.get_charge()) # return the total charge of the system
278280

279281
If a valence of 3 is detected on carbon, the formal charge will be assigned to -1, because in most cases (alkynyl anion, isonitrile, cyclopentadienyl anion) the formal charge on a 3-valence carbon is -1; this is also consistent with the 8-electron rule.
280282

283+
## Mixed Type Format
284+
The format `deepmd/npy/mixed` is the mixed type numpy format for DeePMD-kit, and can be loaded or dumped through class `dpdata.MultiSystems`.
285+
286+
Under this format, systems with the same number of atoms but different formulas can be put together
287+
into a larger system, which is especially useful when each system contains only a few frames.
288+
289+
This also helps to mix the type information together for model training with a type embedding network.
290+
291+
Here are examples using `deepmd/npy/mixed` format:
292+
293+
- Dump a MultiSystems into a mixed type numpy directory:
294+
```python
295+
import dpdata
296+
297+
dpdata.MultiSystems(*systems).to_deepmd_npy_mixed("mixed_dir")
298+
```
299+
300+
- Load a mixed type data into a MultiSystems:
301+
```python
302+
import dpdata
303+
304+
dpdata.MultiSystems().load_systems_from_file("mixed_dir", fmt="deepmd/npy/mixed")
305+
```
306+
281307
# Plugins
282308

283309
One can follow [a simple example](plugin_example/) to add their own format by creating and installing plugins. It's critical to add the [Format](dpdata/format.py) class to `entry_points['dpdata.plugins']` in [`pyproject.toml`](plugin_example/pyproject.toml):

dpdata/deepmd/mixed.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
import glob
2+
import os
3+
import shutil
4+
5+
import numpy as np
6+
7+
8+
def load_type(folder):
    """Read the atom names of a mixed-type system from ``type_map.raw``.

    Parameters
    ----------
    folder : str
        Directory that must contain a ``type_map.raw`` file.

    Returns
    -------
    dict
        A dict with the single key ``"atom_names"`` holding the
        whitespace-separated entries of ``type_map.raw``.

    Raises
    ------
    AssertionError
        If ``type_map.raw`` is not a file inside ``folder``.
    """
    type_map_path = os.path.join(folder, "type_map.raw")
    # A mixed-type directory cannot be interpreted without its type map.
    assert os.path.isfile(type_map_path), "Mixed type system must have type_map.raw!"
    with open(type_map_path) as handle:
        atom_names = handle.read().split()
    return {"atom_names": atom_names}
19+
20+
21+
def formula(atom_names, atom_numbs):
    """Build a formula string such as ``C3H5O2``.

    Each name is immediately followed by its count; pairs are taken in
    order from ``atom_names`` and ``atom_numbs`` and stop at the shorter
    of the two sequences.
    """
    pieces = []
    for symbol, count in zip(atom_names, atom_numbs):
        pieces.append(f"{symbol}{count}")
    return "".join(pieces)
28+
29+
30+
def _cond_load_data(fname):
31+
tmp = None
32+
if os.path.isfile(fname):
33+
tmp = np.load(fname)
34+
return tmp
35+
36+
37+
def _load_set(folder, nopbc: bool):
    """Load the arrays stored in one ``set.*`` directory.

    Parameters
    ----------
    folder : str
        Path of the set directory.
    nopbc : bool
        If True, ``box.npy`` is not read and a zero array with one 3x3
        cell per row of ``coord.npy`` is used instead.

    Returns
    -------
    tuple
        ``(cells, coords, eners, forces, virs, real_atom_types)``.
        ``eners``, ``forces`` and ``virs`` are ``None`` when the
        corresponding file is absent.
    """

    def _in_set(name):
        # All arrays live directly inside the set directory.
        return os.path.join(folder, name)

    coords = np.load(_in_set("coord.npy"))
    cells = np.zeros((coords.shape[0], 3, 3)) if nopbc else np.load(_in_set("box.npy"))
    # Labels are optional: each one is None if its file is missing.
    eners, forces, virs = (
        _cond_load_data(_in_set(name))
        for name in ("energy.npy", "force.npy", "virial.npy")
    )
    real_atom_types = np.load(_in_set("real_atom_types.npy"))
    return cells, coords, eners, forces, virs, real_atom_types
48+
49+
50+
def to_system_data(folder, type_map=None, labels=True):
    """Load a mixed-type deepmd/npy directory as a list of system data dicts.

    Frames whose per-atom type rows are identical are grouped together;
    one data dict is produced per distinct row, in order of first
    appearance.

    Parameters
    ----------
    folder : str
        Directory holding ``type_map.raw``, exactly one ``set.*``
        sub-directory and, optionally, a ``nopbc`` marker file.
    type_map : list of str, optional
        Accepted for interface compatibility; not used by this function.
    labels : bool
        If True, ``energies``, ``forces`` and ``virials`` are attached to
        each group whenever the corresponding array was loaded and is
        non-empty.

    Returns
    -------
    list of dict
        Each dict has ``atom_names``, ``orig``, ``atom_numbs``,
        ``atom_types``, ``cells`` and ``coords``, plus ``nopbc`` and the
        label arrays when applicable. The dicts do not share mutable
        objects with one another.

    Raises
    ------
    AssertionError
        If ``type_map.raw`` is missing or ``folder`` does not contain
        exactly one ``set.*`` entry.
    """
    data = load_type(folder)
    data["orig"] = np.zeros([3])
    if os.path.isfile(os.path.join(folder, "nopbc")):
        data["nopbc"] = True
    sets = sorted(glob.glob(os.path.join(folder, "set.*")))
    assert len(sets) == 1, "Mixed type must have only one set!"
    cells, coords, eners, forces, virs, real_atom_types = _load_set(
        sets[0], data.get("nopbc", False)
    )

    # The frame count is taken from the cells; everything else is reshaped
    # to match it.
    nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
    real_atom_types = np.reshape(real_atom_types, [nframes, -1])

    # Per-frame arrays that are partitioned alongside the type rows.
    frame_data = {
        "cells": np.reshape(cells, [nframes, 3, 3]),
        "coords": np.reshape(coords, [nframes, -1, 3]),
    }
    if labels:
        if eners is not None and eners.size > 0:
            frame_data["energies"] = np.reshape(eners, [nframes])
        if forces is not None and forces.size > 0:
            frame_data["forces"] = np.reshape(forces, [nframes, -1, 3])
        if virs is not None and virs.size > 0:
            frame_data["virials"] = np.reshape(virs, [nframes, 3, 3])

    ntypes = len(data["atom_names"])
    data_list = []
    # Repeatedly peel off all frames that share the type row of the first
    # remaining frame. The first frame always matches itself, so the array
    # shrinks on every pass.
    while real_atom_types.size > 0:
        # Copy so the result does not keep a view into the full type array.
        first_types = real_atom_types[0].copy()
        same = (real_atom_types == first_types).all(-1)
        rest = ~same

        temp_data = data.copy()
        # dict.copy() is shallow: give every system its own mutable objects
        # so that editing one result cannot silently change the others.
        temp_data["atom_names"] = list(data["atom_names"])
        temp_data["orig"] = data["orig"].copy()
        temp_data["atom_numbs"] = [
            np.count_nonzero(first_types == i) for i in range(ntypes)
        ]
        temp_data["atom_types"] = first_types
        for key in list(frame_data):
            values = frame_data[key]
            temp_data[key] = values[same]
            frame_data[key] = values[rest]
        real_atom_types = real_atom_types[rest]
        data_list.append(temp_data)
    return data_list
108+
109+
110+
def dump(folder, data, comp_prec=np.float32, remove_sets=True):
    """Write one mixed-type system to ``folder`` in deepmd/npy layout.

    Writes ``type.raw`` and ``type_map.raw`` (plus ``bonds.raw`` and
    ``formal_charges.raw`` when those keys are present) and a single
    ``set.000`` directory with ``box``, ``coord`` and, when present in
    ``data``, ``energy``, ``force``, ``virial``, ``real_atom_types`` and
    ``atom_pref`` arrays. A ``nopbc`` marker file is (re)created only if
    ``data["nopbc"]`` is truthy.

    Parameters
    ----------
    folder : str
        Output directory; created if it does not exist.
    data : dict
        System data. Must provide ``atom_types``, ``real_atom_names``,
        ``cells`` and ``coords`` by the time the files are written.
    comp_prec : numpy dtype
        Precision the floating-point frame arrays are cast to.
    remove_sets : bool
        If True, existing ``set.*`` directories in ``folder`` are deleted
        first; if False, their presence raises ``RuntimeError``.

    Raises
    ------
    RuntimeError
        If ``set.*`` entries already exist and ``remove_sets`` is False.
    """
    os.makedirs(folder, exist_ok=True)
    sets = sorted(glob.glob(os.path.join(folder, "set.*")))
    if sets:
        if remove_sets:
            for stale_set in sets:
                shutil.rmtree(stale_set)
        else:
            raise RuntimeError(
                f"found {sets} in {folder} not a clean deepmd raw dir. "
                "please firstly clean set.* then try compress"
            )
    # if not converted to mixed
    if "real_atom_types" not in data:
        from dpdata import LabeledSystem, System

        if "energies" in data:
            temp_sys = LabeledSystem(data=data)
        else:
            temp_sys = System(data=data)
        # NOTE(review): the code below reads the mixed-type keys from
        # `data`, so this presumably relies on the system wrapping `data`
        # without copying and converting it in place - confirm.
        temp_sys.convert_to_mixed_type()
    # dump raw
    np.savetxt(os.path.join(folder, "type.raw"), data["atom_types"], fmt="%d")
    np.savetxt(os.path.join(folder, "type_map.raw"), data["real_atom_names"], fmt="%s")
    # BondOrder System
    if "bonds" in data:
        np.savetxt(
            os.path.join(folder, "bonds.raw"),
            data["bonds"],
            header="begin_atom, end_atom, bond_order",
        )
    if "formal_charges" in data:
        np.savetxt(os.path.join(folder, "formal_charges.raw"), data["formal_charges"])
    # reshape frame properties and convert prec
    nframes = data["cells"].shape[0]
    cells = np.reshape(data["cells"], [nframes, 9]).astype(comp_prec)
    coords = np.reshape(data["coords"], [nframes, -1]).astype(comp_prec)
    eners = None
    forces = None
    virials = None
    atom_pref = None
    real_atom_types = None
    if "energies" in data:
        eners = np.reshape(data["energies"], [nframes]).astype(comp_prec)
    if "forces" in data:
        forces = np.reshape(data["forces"], [nframes, -1]).astype(comp_prec)
    if "virials" in data:
        virials = np.reshape(data["virials"], [nframes, 9]).astype(comp_prec)
    if "atom_pref" in data:
        atom_pref = np.reshape(data["atom_pref"], [nframes, -1]).astype(comp_prec)
    if "real_atom_types" in data:
        # Types are indices, so they are stored as integers regardless of
        # comp_prec.
        real_atom_types = np.reshape(data["real_atom_types"], [nframes, -1]).astype(
            np.int64
        )
    # dump frame properties: cell, coord, energy, force and virial
    set_folder = os.path.join(folder, "set.%03d" % 0)
    os.makedirs(set_folder)
    np.save(os.path.join(set_folder, "box"), cells)
    np.save(os.path.join(set_folder, "coord"), coords)
    if eners is not None:
        np.save(os.path.join(set_folder, "energy"), eners)
    if forces is not None:
        np.save(os.path.join(set_folder, "force"), forces)
    if virials is not None:
        np.save(os.path.join(set_folder, "virial"), virials)
    if real_atom_types is not None:
        np.save(os.path.join(set_folder, "real_atom_types"), real_atom_types)
    if atom_pref is not None:
        np.save(os.path.join(set_folder, "atom_pref"), atom_pref)
    # Drop any stale marker first (best effort: it usually does not exist),
    # then recreate it only for non-periodic data.
    try:
        os.remove(os.path.join(folder, "nopbc"))
    except OSError:
        pass
    if data.get("nopbc", False):
        with open(os.path.join(folder, "nopbc"), "w"):
            pass
188+
189+
190+
def mix_system(*system, type_map, split_num=200, **kwargs):
    """Merge systems with equal atom counts into mixed-type systems.

    Every input is copied, converted to mixed type with ``type_map`` and
    appended to the pending system of its atom count. Whenever a pending
    system reaches ``split_num`` frames, chunks of ``split_num`` frames
    are split off and stored; whatever is left at the end is stored too.

    Parameters
    ----------
    *system : System
        The systems to mix
    type_map : list of str
        Maps atom type to name
    split_num : int
        Number of frames in each system

    Returns
    -------
    mixed_systems: dict
        dict of mixed system with key '{natoms}/sys.xxxxxx'
    """
    mixed_systems = {}
    pending = {}  # natoms -> system still being filled
    next_sys_index = {}  # natoms -> index used for the next stored system
    pending_frames = {}  # natoms -> number of frames in the pending system
    for one_system in system:
        converted = one_system.copy()
        key = str(converted.get_natoms())
        converted.convert_to_mixed_type(type_map=type_map)
        next_sys_index.setdefault(key, 0)
        pending_frames[key] = pending_frames.get(key, 0) + converted.get_nframes()
        if pending.get(key):
            pending[key].append(converted)
        else:
            # Nothing pending for this atom count (never seen, or fully
            # flushed by an earlier split).
            pending[key] = converted
        if pending_frames[key] < split_num:
            continue
        while True:
            chunk, pending[key], rest_num = split_system(
                pending[key], split_num=split_num
            )
            mixed_systems[f"{key}/sys." + "%.6d" % next_sys_index[key]] = chunk
            next_sys_index[key] += 1
            if rest_num < split_num:
                pending_frames[key] = rest_num
                break
    # Flush the partially filled systems.
    for key, leftover in pending.items():
        if pending_frames[key] > 0:
            mixed_systems[f"{key}/sys." + "%.6d" % next_sys_index[key]] = leftover
    return mixed_systems
242+
243+
244+
def split_system(sys, split_num=100):
    """Split off the first ``split_num`` frames of ``sys``.

    Returns
    -------
    tuple
        ``(head, tail, rest)``. If ``sys`` has more than ``split_num``
        frames, ``head`` holds the first ``split_num`` frames, ``tail``
        the remaining ones and ``rest`` their count. Otherwise ``sys``
        itself is returned as ``head`` with ``tail`` ``None`` and
        ``rest`` ``0``.
    """
    rest = sys.get_nframes() - split_num
    if rest > 0:
        head = sys.sub_system(range(split_num))
        tail = sys.sub_system(range(split_num, sys.get_nframes()))
        return head, tail, rest
    # Nothing left over: hand back the input unchanged.
    return sys, None, 0

dpdata/format.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,24 @@ def to_multi_systems(self, formulas, directory, **kwargs):
131131
raise NotImplementedError(
132132
"%s doesn't support MultiSystems.to" % (self.__class__.__name__)
133133
)
134+
135+
def mix_system(self, *system, type_map, split_num=200, **kwargs):
    """Mix the systems into mixed_type ones according to the unified given type_map.

    This base implementation always raises; formats that support mixing
    override it.

    Parameters
    ----------
    *system : System
        The systems to mix
    type_map : list of str
        Maps atom type to name
    split_num : int
        Number of frames in each system

    Returns
    -------
    mixed_systems: dict
        dict of mixed system with key '{atom_numbs}/sys.xxx'

    Raises
    ------
    NotImplementedError
        Always, naming the class that lacks the implementation.
    """
    # Name the operation that is actually missing so the error is actionable.
    raise NotImplementedError(
        "%s doesn't support mix_system" % (self.__class__.__name__)
    )

0 commit comments

Comments
 (0)