Skip to content

Commit 35fdbb8

Browse files
njzjz and wanghan-iapcm authored
fix: reuse regular methods for deepmd/mixed (#704)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Streamlined data loading process with improved handling of atom types.
  - Enhanced modularization in data dumping, reducing complexity and improving maintainability.
  - Added a new test for verifying data conversion and loading processes for mixed format systems.
- **Bug Fixes**
  - Improved clarity and efficiency in processing atom names and counts.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
1 parent 6d082f1 commit 35fdbb8

File tree

4 files changed

+31
-170
lines changed

4 files changed

+31
-170
lines changed

dpdata/deepmd/comp.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,6 @@ def to_system_data(folder, type_map=None, labels=True):
7979
"orig",
8080
"cells",
8181
"coords",
82-
"real_atom_types",
8382
"real_atom_names",
8483
"nopbc",
8584
"energies",
@@ -189,7 +188,6 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
189188
"orig",
190189
"cells",
191190
"coords",
192-
"real_atom_types",
193191
"real_atom_names",
194192
"nopbc",
195193
"energies",

dpdata/deepmd/mixed.py

Lines changed: 17 additions & 162 deletions
Original file line number | Diff line number | Diff line change
@@ -1,54 +1,16 @@
11
from __future__ import annotations
22

33
import copy
4-
import glob
5-
import os
6-
import shutil
74

85
import numpy as np
96

10-
11-
def load_type(folder):
12-
data = {}
13-
data["atom_names"] = []
14-
# if find type_map.raw, use it
15-
assert os.path.isfile(
16-
os.path.join(folder, "type_map.raw")
17-
), "Mixed type system must have type_map.raw!"
18-
with open(os.path.join(folder, "type_map.raw")) as fp:
19-
data["atom_names"] = fp.read().split()
20-
21-
return data
22-
23-
24-
def formula(atom_names, atom_numbs):
25-
"""Return the formula of this system, like C3H5O2."""
26-
return "".join([f"{symbol}{numb}" for symbol, numb in zip(atom_names, atom_numbs)])
27-
28-
29-
def _cond_load_data(fname):
30-
tmp = None
31-
if os.path.isfile(fname):
32-
tmp = np.load(fname)
33-
return tmp
34-
35-
36-
def _load_set(folder, nopbc: bool):
37-
coords = np.load(os.path.join(folder, "coord.npy"))
38-
if nopbc:
39-
cells = np.zeros((coords.shape[0], 3, 3))
40-
else:
41-
cells = np.load(os.path.join(folder, "box.npy"))
42-
eners = _cond_load_data(os.path.join(folder, "energy.npy"))
43-
forces = _cond_load_data(os.path.join(folder, "force.npy"))
44-
virs = _cond_load_data(os.path.join(folder, "virial.npy"))
45-
real_atom_types = np.load(os.path.join(folder, "real_atom_types.npy"))
46-
return cells, coords, eners, forces, virs, real_atom_types
7+
from .comp import dump as comp_dump
8+
from .comp import to_system_data as comp_to_system_data
479

4810

4911
def to_system_data(folder, type_map=None, labels=True):
12+
data = comp_to_system_data(folder, type_map, labels)
5013
# data is empty
51-
data = load_type(folder)
5214
old_type_map = data["atom_names"].copy()
5315
if type_map is not None:
5416
assert isinstance(type_map, list)
@@ -60,50 +22,16 @@ def to_system_data(folder, type_map=None, labels=True):
6022
data["atom_names"] = type_map.copy()
6123
else:
6224
index_map = None
63-
data["orig"] = np.zeros([3])
64-
if os.path.isfile(os.path.join(folder, "nopbc")):
65-
data["nopbc"] = True
66-
sets = sorted(glob.glob(os.path.join(folder, "set.*")))
67-
all_cells = []
68-
all_coords = []
69-
all_eners = []
70-
all_forces = []
71-
all_virs = []
72-
all_real_atom_types = []
73-
for ii in sets:
74-
cells, coords, eners, forces, virs, real_atom_types = _load_set(
75-
ii, data.get("nopbc", False)
76-
)
77-
nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
78-
all_cells.append(np.reshape(cells, [nframes, 3, 3]))
79-
all_coords.append(np.reshape(coords, [nframes, -1, 3]))
80-
if index_map is None:
81-
all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1]))
82-
else:
83-
all_real_atom_types.append(
84-
np.reshape(index_map[real_atom_types], [nframes, -1])
85-
)
86-
if eners is not None:
87-
eners = np.reshape(eners, [nframes])
88-
if labels:
89-
if eners is not None and eners.size > 0:
90-
all_eners.append(np.reshape(eners, [nframes]))
91-
if forces is not None and forces.size > 0:
92-
all_forces.append(np.reshape(forces, [nframes, -1, 3]))
93-
if virs is not None and virs.size > 0:
94-
all_virs.append(np.reshape(virs, [nframes, 3, 3]))
95-
all_cells_concat = np.concatenate(all_cells, axis=0)
96-
all_coords_concat = np.concatenate(all_coords, axis=0)
97-
all_real_atom_types_concat = np.concatenate(all_real_atom_types, axis=0)
98-
all_eners_concat = None
99-
all_forces_concat = None
100-
all_virs_concat = None
101-
if len(all_eners) > 0:
102-
all_eners_concat = np.concatenate(all_eners, axis=0)
103-
if len(all_forces) > 0:
104-
all_forces_concat = np.concatenate(all_forces, axis=0)
105-
if len(all_virs) > 0:
106-
all_virs_concat = np.concatenate(all_virs, axis=0)
25+
all_real_atom_types_concat = data.pop("real_atom_types").astype(int)
26+
if index_map is not None:
27+
all_real_atom_types_concat = index_map[all_real_atom_types_concat]
28+
all_cells_concat = data["cells"]
29+
all_coords_concat = data["coords"]
30+
if labels:
31+
all_eners_concat = data.get("energies")
32+
all_forces_concat = data.get("forces")
33+
all_virs_concat = data.get("virials")
34+
10735
data_list = []
10836
while True:
10937
if all_real_atom_types_concat.size == 0:
@@ -143,20 +71,6 @@ def to_system_data(folder, type_map=None, labels=True):
14371

14472

14573
def dump(folder, data, set_size=2000, comp_prec=np.float32, remove_sets=True):
146-
os.makedirs(folder, exist_ok=True)
147-
sets = sorted(glob.glob(os.path.join(folder, "set.*")))
148-
if len(sets) > 0:
149-
if remove_sets:
150-
for ii in sets:
151-
shutil.rmtree(ii)
152-
else:
153-
raise RuntimeError(
154-
"found "
155-
+ str(sets)
156-
+ " in "
157-
+ folder
158-
+ "not a clean deepmd raw dir. please firstly clean set.* then try compress"
159-
)
16074
# if not converted to mixed
16175
if "real_atom_types" not in data:
16276
from dpdata import LabeledSystem, System
@@ -169,69 +83,10 @@ def dump(folder, data, set_size=2000, comp_prec=np.float32, remove_sets=True):
16983
else:
17084
temp_sys = System(data=data)
17185
temp_sys.convert_to_mixed_type()
172-
# dump raw
173-
np.savetxt(os.path.join(folder, "type.raw"), data["atom_types"], fmt="%d")
174-
np.savetxt(os.path.join(folder, "type_map.raw"), data["real_atom_names"], fmt="%s")
175-
# BondOrder System
176-
if "bonds" in data:
177-
np.savetxt(
178-
os.path.join(folder, "bonds.raw"),
179-
data["bonds"],
180-
header="begin_atom, end_atom, bond_order",
181-
)
182-
if "formal_charges" in data:
183-
np.savetxt(os.path.join(folder, "formal_charges.raw"), data["formal_charges"])
184-
# reshape frame properties and convert prec
185-
nframes = data["cells"].shape[0]
186-
cells = np.reshape(data["cells"], [nframes, 9]).astype(comp_prec)
187-
coords = np.reshape(data["coords"], [nframes, -1]).astype(comp_prec)
188-
eners = None
189-
forces = None
190-
virials = None
191-
real_atom_types = None
192-
if "energies" in data:
193-
eners = np.reshape(data["energies"], [nframes]).astype(comp_prec)
194-
if "forces" in data:
195-
forces = np.reshape(data["forces"], [nframes, -1]).astype(comp_prec)
196-
if "virials" in data:
197-
virials = np.reshape(data["virials"], [nframes, 9]).astype(comp_prec)
198-
if "atom_pref" in data:
199-
atom_pref = np.reshape(data["atom_pref"], [nframes, -1]).astype(comp_prec)
200-
if "real_atom_types" in data:
201-
real_atom_types = np.reshape(data["real_atom_types"], [nframes, -1]).astype(
202-
np.int64
203-
)
204-
# dump frame properties: cell, coord, energy, force and virial
205-
nsets = nframes // set_size
206-
if set_size * nsets < nframes:
207-
nsets += 1
208-
for ii in range(nsets):
209-
set_stt = ii * set_size
210-
set_end = (ii + 1) * set_size
211-
set_folder = os.path.join(folder, "set.%06d" % ii)
212-
os.makedirs(set_folder)
213-
np.save(os.path.join(set_folder, "box"), cells[set_stt:set_end])
214-
np.save(os.path.join(set_folder, "coord"), coords[set_stt:set_end])
215-
if eners is not None:
216-
np.save(os.path.join(set_folder, "energy"), eners[set_stt:set_end])
217-
if forces is not None:
218-
np.save(os.path.join(set_folder, "force"), forces[set_stt:set_end])
219-
if virials is not None:
220-
np.save(os.path.join(set_folder, "virial"), virials[set_stt:set_end])
221-
if real_atom_types is not None:
222-
np.save(
223-
os.path.join(set_folder, "real_atom_types"),
224-
real_atom_types[set_stt:set_end],
225-
)
226-
if "atom_pref" in data:
227-
np.save(os.path.join(set_folder, "atom_pref"), atom_pref[set_stt:set_end])
228-
try:
229-
os.remove(os.path.join(folder, "nopbc"))
230-
except OSError:
231-
pass
232-
if data.get("nopbc", False):
233-
with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
234-
pass
86+
87+
data = data.copy()
88+
data["atom_names"] = data.pop("real_atom_names")
89+
comp_dump(folder, data, set_size, comp_prec, remove_sets)
23590

23691

23792
def mix_system(*system, type_map, **kwargs):

dpdata/deepmd/raw.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,6 @@ def load_type(folder, type_map=None):
1414
int
1515
)
1616
ntypes = np.max(data["atom_types"]) + 1
17-
data["atom_numbs"] = []
18-
for ii in range(ntypes):
19-
data["atom_numbs"].append(np.count_nonzero(data["atom_types"] == ii))
2017
data["atom_names"] = []
2118
# if find type_map.raw, use it
2219
if os.path.isfile(os.path.join(folder, "type_map.raw")):
@@ -30,9 +27,10 @@ def load_type(folder, type_map=None):
3027
my_type_map = []
3128
for ii in range(ntypes):
3229
my_type_map.append("Type_%d" % ii)
33-
assert len(my_type_map) >= len(data["atom_numbs"])
34-
for ii in range(len(data["atom_numbs"])):
35-
data["atom_names"].append(my_type_map[ii])
30+
data["atom_names"] = my_type_map
31+
data["atom_numbs"] = []
32+
for ii, _ in enumerate(data["atom_names"]):
33+
data["atom_numbs"].append(np.count_nonzero(data["atom_types"] == ii))
3634

3735
return data
3836

tests/test_custom_data_type.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,16 @@ def test_duplicated_data_type(self):
5757
n_dtypes_new = len(self.cls.DTYPES)
5858
self.assertEqual(n_dtypes_old, n_dtypes_new)
5959

60+
def test_to_deepmd_npy_mixed(self):
61+
ms = dpdata.MultiSystems(self.system)
62+
ms.to_deepmd_npy_mixed("data_foo_mixed")
63+
x = dpdata.MultiSystems().load_systems_from_file(
64+
"data_foo_mixed",
65+
fmt="deepmd/npy/mixed",
66+
labeled=issubclass(self.cls, dpdata.LabeledSystem),
67+
)
68+
np.testing.assert_allclose(list(x.systems.values())[0].data["foo"], self.foo)
69+
6070

6171
class TestDeepmdLoadDumpCompUnlabeled(unittest.TestCase, DeepmdLoadDumpCompTest):
6272
cls = dpdata.System

0 commit comments

Comments
 (0)