Skip to content

Commit a2fbdd8

Browse files
authored
feat: support data type dumped to a different name (#727)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced an optional `deepmd_name` parameter in the `DataType` class for enhanced naming flexibility. - Updated data type declarations in the `System` and `LabeledSystem` classes for better integration with the DeepMD framework. - **Bug Fixes** - Removed handling of energy, force, and virial data to simplify data processing and storage. - **Documentation** - Updated documentation for the `DataType` class to clarify the new `deepmd_name` parameter. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 480242e commit a2fbdd8

File tree

5 files changed

+35
-157
lines changed

5 files changed

+35
-157
lines changed

dpdata/data_type.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class DataType:
4646
represents numbers
4747
required : bool, default=True
4848
whether this data is required
49+
deepmd_name : str, optional
50+
DeePMD-kit data type name. When not given, it is the same as `name`.
4951
"""
5052

5153
def __init__(
@@ -54,11 +56,13 @@ def __init__(
5456
dtype: type,
5557
shape: tuple[int | Axis, ...] | None = None,
5658
required: bool = True,
59+
deepmd_name: str | None = None,
5760
) -> None:
5861
self.name = name
5962
self.dtype = dtype
6063
self.shape = shape
6164
self.required = required
65+
self.deepmd_name = name if deepmd_name is None else deepmd_name
6266

6367
def real_shape(self, system: System) -> tuple[int]:
6468
"""Returns expected real shape of a system."""

dpdata/deepmd/comp.py

Lines changed: 5 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@ def _load_set(folder, nopbc: bool):
2626
cells = np.zeros((coords.shape[0], 3, 3))
2727
else:
2828
cells = np.load(os.path.join(folder, "box.npy"))
29-
eners = _cond_load_data(os.path.join(folder, "energy.npy"))
30-
forces = _cond_load_data(os.path.join(folder, "force.npy"))
31-
virs = _cond_load_data(os.path.join(folder, "virial.npy"))
32-
return cells, coords, eners, forces, virs
29+
return cells, coords
3330

3431

3532
def to_system_data(folder, type_map=None, labels=True):
@@ -41,31 +38,13 @@ def to_system_data(folder, type_map=None, labels=True):
4138
sets = sorted(glob.glob(os.path.join(folder, "set.*")))
4239
all_cells = []
4340
all_coords = []
44-
all_eners = []
45-
all_forces = []
46-
all_virs = []
4741
for ii in sets:
48-
cells, coords, eners, forces, virs = _load_set(ii, data.get("nopbc", False))
42+
cells, coords = _load_set(ii, data.get("nopbc", False))
4943
nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
5044
all_cells.append(np.reshape(cells, [nframes, 3, 3]))
5145
all_coords.append(np.reshape(coords, [nframes, -1, 3]))
52-
if eners is not None:
53-
eners = np.reshape(eners, [nframes])
54-
if labels:
55-
if eners is not None and eners.size > 0:
56-
all_eners.append(np.reshape(eners, [nframes]))
57-
if forces is not None and forces.size > 0:
58-
all_forces.append(np.reshape(forces, [nframes, -1, 3]))
59-
if virs is not None and virs.size > 0:
60-
all_virs.append(np.reshape(virs, [nframes, 3, 3]))
6146
data["cells"] = np.concatenate(all_cells, axis=0)
6247
data["coords"] = np.concatenate(all_coords, axis=0)
63-
if len(all_eners) > 0:
64-
data["energies"] = np.concatenate(all_eners, axis=0)
65-
if len(all_forces) > 0:
66-
data["forces"] = np.concatenate(all_forces, axis=0)
67-
if len(all_virs) > 0:
68-
data["virials"] = np.concatenate(all_virs, axis=0)
6948
# allow custom dtypes
7049
if labels:
7150
dtypes = dpdata.system.LabeledSystem.DTYPES
@@ -82,9 +61,6 @@ def to_system_data(folder, type_map=None, labels=True):
8261
"coords",
8362
"real_atom_names",
8463
"nopbc",
85-
"energies",
86-
"forces",
87-
"virials",
8864
):
8965
# skip as these data contains specific rules
9066
continue
@@ -93,13 +69,13 @@ def to_system_data(folder, type_map=None, labels=True):
9369
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format."
9470
)
9571
continue
96-
natoms = data["coords"].shape[1]
72+
natoms = data["atom_types"].shape[0]
9773
shape = [
9874
natoms if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]
9975
]
10076
all_data = []
10177
for ii in sets:
102-
tmp = _cond_load_data(os.path.join(ii, dtype.name + ".npy"))
78+
tmp = _cond_load_data(os.path.join(ii, dtype.deepmd_name + ".npy"))
10379
if tmp is not None:
10480
all_data.append(np.reshape(tmp, [tmp.shape[0], *shape]))
10581
if len(all_data) > 0:
@@ -136,19 +112,6 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
136112
np.savetxt(os.path.join(folder, "formal_charges.raw"), data["formal_charges"])
137113
# reshape frame properties and convert prec
138114
nframes = data["cells"].shape[0]
139-
cells = np.reshape(data["cells"], [nframes, 9]).astype(comp_prec)
140-
coords = np.reshape(data["coords"], [nframes, -1]).astype(comp_prec)
141-
eners = None
142-
forces = None
143-
virials = None
144-
if "energies" in data:
145-
eners = np.reshape(data["energies"], [nframes]).astype(comp_prec)
146-
if "forces" in data:
147-
forces = np.reshape(data["forces"], [nframes, -1]).astype(comp_prec)
148-
if "virials" in data:
149-
virials = np.reshape(data["virials"], [nframes, 9]).astype(comp_prec)
150-
if "atom_pref" in data:
151-
atom_pref = np.reshape(data["atom_pref"], [nframes, -1]).astype(comp_prec)
152115
# dump frame properties: cell, coord, energy, force and virial
153116
nsets = nframes // set_size
154117
if set_size * nsets < nframes:
@@ -158,16 +121,6 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
158121
set_end = (ii + 1) * set_size
159122
set_folder = os.path.join(folder, "set.%03d" % ii)
160123
os.makedirs(set_folder)
161-
np.save(os.path.join(set_folder, "box"), cells[set_stt:set_end])
162-
np.save(os.path.join(set_folder, "coord"), coords[set_stt:set_end])
163-
if eners is not None:
164-
np.save(os.path.join(set_folder, "energy"), eners[set_stt:set_end])
165-
if forces is not None:
166-
np.save(os.path.join(set_folder, "force"), forces[set_stt:set_end])
167-
if virials is not None:
168-
np.save(os.path.join(set_folder, "virial"), virials[set_stt:set_end])
169-
if "atom_pref" in data:
170-
np.save(os.path.join(set_folder, "atom_pref"), atom_pref[set_stt:set_end])
171124
try:
172125
os.remove(os.path.join(folder, "nopbc"))
173126
except OSError:
@@ -187,13 +140,8 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
187140
"atom_names",
188141
"atom_types",
189142
"orig",
190-
"cells",
191-
"coords",
192143
"real_atom_names",
193144
"nopbc",
194-
"energies",
195-
"forces",
196-
"virials",
197145
):
198146
# skip as these data contains specific rules
199147
continue
@@ -211,4 +159,4 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
211159
set_stt = ii * set_size
212160
set_end = (ii + 1) * set_size
213161
set_folder = os.path.join(folder, "set.%03d" % ii)
214-
np.save(os.path.join(set_folder, dtype.name), ddata[set_stt:set_end])
162+
np.save(os.path.join(set_folder, dtype.deepmd_name), ddata[set_stt:set_end])

dpdata/deepmd/hdf5.py

Lines changed: 7 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -69,38 +69,7 @@ def to_system_data(
6969
data["nopbc"] = True
7070
sets = globfilter(g.keys(), "set.*")
7171

72-
data_types = {
73-
"cells": {
74-
"fn": "box",
75-
"labeled": False,
76-
"shape": (3, 3),
77-
"required": "nopbc" not in data,
78-
},
79-
"coords": {
80-
"fn": "coord",
81-
"labeled": False,
82-
"shape": (natoms, 3),
83-
"required": True,
84-
},
85-
"energies": {
86-
"fn": "energy",
87-
"labeled": True,
88-
"shape": tuple(),
89-
"required": False,
90-
},
91-
"forces": {
92-
"fn": "force",
93-
"labeled": True,
94-
"shape": (natoms, 3),
95-
"required": False,
96-
},
97-
"virials": {
98-
"fn": "virial",
99-
"labeled": True,
100-
"shape": (3, 3),
101-
"required": False,
102-
},
103-
}
72+
data_types = {}
10473
# allow custom dtypes
10574
if labels:
10675
dtypes = dpdata.system.LabeledSystem.DTYPES
@@ -112,14 +81,9 @@ def to_system_data(
11281
"atom_names",
11382
"atom_types",
11483
"orig",
115-
"cells",
116-
"coords",
11784
"real_atom_types",
11885
"real_atom_names",
11986
"nopbc",
120-
"energies",
121-
"forces",
122-
"virials",
12387
):
12488
# skip as these data contains specific rules
12589
continue
@@ -133,10 +97,10 @@ def to_system_data(
13397
]
13498

13599
data_types[dtype.name] = {
136-
"fn": dtype.name,
137-
"labeled": True,
100+
"fn": dtype.deepmd_name,
138101
"shape": shape,
139-
"required": False,
102+
"required": dtype.required
103+
and not (dtype.name == "cells" and data.get("nopbc", False)),
140104
}
141105

142106
for dt, prop in data_types.items():
@@ -206,13 +170,7 @@ def dump(
206170
nopbc = data.get("nopbc", False)
207171
reshaped_data = {}
208172

209-
data_types = {
210-
"cells": {"fn": "box", "shape": (nframes, 9), "dump": not nopbc},
211-
"coords": {"fn": "coord", "shape": (nframes, -1), "dump": True},
212-
"energies": {"fn": "energy", "shape": (nframes,), "dump": True},
213-
"forces": {"fn": "force", "shape": (nframes, -1), "dump": True},
214-
"virials": {"fn": "virial", "shape": (nframes, 9), "dump": True},
215-
}
173+
data_types = {}
216174

217175
labels = "energies" in data
218176
if labels:
@@ -226,14 +184,9 @@ def dump(
226184
"atom_names",
227185
"atom_types",
228186
"orig",
229-
"cells",
230-
"coords",
231187
"real_atom_types",
232188
"real_atom_names",
233189
"nopbc",
234-
"energies",
235-
"forces",
236-
"virials",
237190
):
238191
# skip as these data contains specific rules
239192
continue
@@ -244,9 +197,9 @@ def dump(
244197
continue
245198

246199
data_types[dtype.name] = {
247-
"fn": dtype.name,
200+
"fn": dtype.deepmd_name,
248201
"shape": (nframes, -1),
249-
"dump": True,
202+
"dump": not (dtype.name == "cells" and nopbc),
250203
}
251204

252205
for dt, prop in data_types.items():

dpdata/deepmd/raw.py

Lines changed: 4 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,6 @@ def to_system_data(folder, type_map=None, labels=True):
4949
data["cells"] = np.loadtxt(os.path.join(folder, "box.raw"), ndmin=2)
5050
data["cells"] = np.reshape(data["cells"], [nframes, 3, 3])
5151
data["coords"] = np.reshape(data["coords"], [nframes, -1, 3])
52-
if labels:
53-
if os.path.exists(os.path.join(folder, "energy.raw")):
54-
data["energies"] = np.loadtxt(os.path.join(folder, "energy.raw"))
55-
data["energies"] = np.reshape(data["energies"], [nframes])
56-
if os.path.exists(os.path.join(folder, "force.raw")):
57-
data["forces"] = np.loadtxt(os.path.join(folder, "force.raw"))
58-
data["forces"] = np.reshape(data["forces"], [nframes, -1, 3])
59-
if os.path.exists(os.path.join(folder, "virial.raw")):
60-
data["virials"] = np.loadtxt(os.path.join(folder, "virial.raw"))
61-
data["virials"] = np.reshape(data["virials"], [nframes, 3, 3])
6252
if os.path.isfile(os.path.join(folder, "nopbc")):
6353
data["nopbc"] = True
6454
# allow custom dtypes
@@ -77,9 +67,6 @@ def to_system_data(folder, type_map=None, labels=True):
7767
"real_atom_types",
7868
"real_atom_names",
7969
"nopbc",
80-
"energies",
81-
"forces",
82-
"virials",
8370
):
8471
# skip as these data contains specific rules
8572
continue
@@ -88,14 +75,14 @@ def to_system_data(folder, type_map=None, labels=True):
8875
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format."
8976
)
9077
continue
91-
natoms = data["coords"].shape[1]
78+
natoms = data["atom_types"].shape[0]
9279
shape = [
9380
natoms if xx == dpdata.system.Axis.NATOMS else xx
9481
for xx in dtype.shape[1:]
9582
]
96-
if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")):
83+
if os.path.exists(os.path.join(folder, f"{dtype.deepmd_name}.raw")):
9784
data[dtype.name] = np.reshape(
98-
np.loadtxt(os.path.join(folder, f"{dtype.name}.raw")),
85+
np.loadtxt(os.path.join(folder, f"{dtype.deepmd_name}.raw")),
9986
[nframes, *shape],
10087
)
10188
return data
@@ -108,10 +95,6 @@ def dump(folder, data):
10895
nframes = data["cells"].shape[0]
10996
np.savetxt(os.path.join(folder, "type.raw"), data["atom_types"], fmt="%d")
11097
np.savetxt(os.path.join(folder, "type_map.raw"), data["atom_names"], fmt="%s")
111-
np.savetxt(os.path.join(folder, "box.raw"), np.reshape(data["cells"], [nframes, 9]))
112-
np.savetxt(
113-
os.path.join(folder, "coord.raw"), np.reshape(data["coords"], [nframes, -1])
114-
)
11598
# BondOrder System
11699
if "bonds" in data:
117100
np.savetxt(
@@ -121,21 +104,6 @@ def dump(folder, data):
121104
)
122105
if "formal_charges" in data:
123106
np.savetxt(os.path.join(folder, "formal_charges.raw"), data["formal_charges"])
124-
# Labeled System
125-
if "energies" in data:
126-
np.savetxt(
127-
os.path.join(folder, "energy.raw"),
128-
np.reshape(data["energies"], [nframes, 1]),
129-
)
130-
if "forces" in data:
131-
np.savetxt(
132-
os.path.join(folder, "force.raw"), np.reshape(data["forces"], [nframes, -1])
133-
)
134-
if "virials" in data:
135-
np.savetxt(
136-
os.path.join(folder, "virial.raw"),
137-
np.reshape(data["virials"], [nframes, 9]),
138-
)
139107
try:
140108
os.remove(os.path.join(folder, "nopbc"))
141109
except OSError:
@@ -155,14 +123,9 @@ def dump(folder, data):
155123
"atom_names",
156124
"atom_types",
157125
"orig",
158-
"cells",
159-
"coords",
160126
"real_atom_types",
161127
"real_atom_names",
162128
"nopbc",
163-
"energies",
164-
"forces",
165-
"virials",
166129
):
167130
# skip as these data contains specific rules
168131
continue
@@ -174,4 +137,4 @@ def dump(folder, data):
174137
)
175138
continue
176139
ddata = np.reshape(data[dtype.name], [nframes, -1])
177-
np.savetxt(os.path.join(folder, f"{dtype.name}.raw"), ddata)
140+
np.savetxt(os.path.join(folder, f"{dtype.deepmd_name}.raw"), ddata)

dpdata/system.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ class System:
9191
DataType("atom_names", list, (Axis.NTYPES,)),
9292
DataType("atom_types", np.ndarray, (Axis.NATOMS,)),
9393
DataType("orig", np.ndarray, (3,)),
94-
DataType("cells", np.ndarray, (Axis.NFRAMES, 3, 3)),
95-
DataType("coords", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3)),
94+
DataType("cells", np.ndarray, (Axis.NFRAMES, 3, 3), deepmd_name="box"),
95+
DataType(
96+
"coords", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="coord"
97+
),
9698
DataType(
9799
"real_atom_types", np.ndarray, (Axis.NFRAMES, Axis.NATOMS), required=False
98100
),
@@ -1204,9 +1206,17 @@ class LabeledSystem(System):
12041206
"""
12051207

12061208
DTYPES: tuple[DataType, ...] = System.DTYPES + (
1207-
DataType("energies", np.ndarray, (Axis.NFRAMES,)),
1208-
DataType("forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3)),
1209-
DataType("virials", np.ndarray, (Axis.NFRAMES, 3, 3), required=False),
1209+
DataType("energies", np.ndarray, (Axis.NFRAMES,), deepmd_name="energy"),
1210+
DataType(
1211+
"forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="force"
1212+
),
1213+
DataType(
1214+
"virials",
1215+
np.ndarray,
1216+
(Axis.NFRAMES, 3, 3),
1217+
required=False,
1218+
deepmd_name="virial",
1219+
),
12101220
DataType("atom_pref", np.ndarray, (Axis.NFRAMES, Axis.NATOMS), required=False),
12111221
)
12121222

0 commit comments

Comments
 (0)