Skip to content

Commit 676517a

Browse files
feat: customized dtypes for unlabeled deepmd (#702)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced flexibility in data type handling for system data processing based on the presence of labels. - New data type registration for "foo" with updated shape and configuration. - **Bug Fixes** - Improved adaptability and performance in data processing functions by refining dtype selection. - **Refactor** - Streamlined test structure for loading and dumping data types, improving organization and maintainability of test cases. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0a50f4d commit 676517a

File tree

5 files changed

+126
-81
lines changed

5 files changed

+126
-81
lines changed

dpdata/deepmd/comp.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -67,40 +67,43 @@ def to_system_data(folder, type_map=None, labels=True):
6767
data["virials"] = np.concatenate(all_virs, axis=0)
6868
# allow custom dtypes
6969
if labels:
70-
for dtype in dpdata.system.LabeledSystem.DTYPES:
71-
if dtype.name in (
72-
"atom_numbs",
73-
"atom_names",
74-
"atom_types",
75-
"orig",
76-
"cells",
77-
"coords",
78-
"real_atom_types",
79-
"real_atom_names",
80-
"nopbc",
81-
"energies",
82-
"forces",
83-
"virials",
84-
):
85-
# skip as these data contains specific rules
86-
continue
87-
if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES):
88-
warnings.warn(
89-
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format."
90-
)
91-
continue
92-
natoms = data["coords"].shape[1]
93-
shape = [
94-
natoms if xx == dpdata.system.Axis.NATOMS else xx
95-
for xx in dtype.shape[1:]
96-
]
97-
all_data = []
98-
for ii in sets:
99-
tmp = _cond_load_data(os.path.join(ii, dtype.name + ".npy"))
100-
if tmp is not None:
101-
all_data.append(np.reshape(tmp, [tmp.shape[0], *shape]))
102-
if len(all_data) > 0:
103-
data[dtype.name] = np.concatenate(all_data, axis=0)
70+
dtypes = dpdata.system.LabeledSystem.DTYPES
71+
else:
72+
dtypes = dpdata.system.System.DTYPES
73+
74+
for dtype in dtypes:
75+
if dtype.name in (
76+
"atom_numbs",
77+
"atom_names",
78+
"atom_types",
79+
"orig",
80+
"cells",
81+
"coords",
82+
"real_atom_types",
83+
"real_atom_names",
84+
"nopbc",
85+
"energies",
86+
"forces",
87+
"virials",
88+
):
89+
# skip as these data contains specific rules
90+
continue
91+
if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES):
92+
warnings.warn(
93+
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format."
94+
)
95+
continue
96+
natoms = data["coords"].shape[1]
97+
shape = [
98+
natoms if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]
99+
]
100+
all_data = []
101+
for ii in sets:
102+
tmp = _cond_load_data(os.path.join(ii, dtype.name + ".npy"))
103+
if tmp is not None:
104+
all_data.append(np.reshape(tmp, [tmp.shape[0], *shape]))
105+
if len(all_data) > 0:
106+
data[dtype.name] = np.concatenate(all_data, axis=0)
104107
return data
105108

106109

@@ -173,6 +176,11 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
173176
with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
174177
pass
175178
# allow custom dtypes
179+
labels = "energies" in data
180+
if labels:
181+
dtypes = dpdata.system.LabeledSystem.DTYPES
182+
else:
183+
dtypes = dpdata.system.System.DTYPES
176184
for dtype in dpdata.system.LabeledSystem.DTYPES:
177185
if dtype.name in (
178186
"atom_numbs",

dpdata/deepmd/hdf5.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ def to_system_data(
102102
},
103103
}
104104
# allow custom dtypes
105-
for dtype in dpdata.system.LabeledSystem.DTYPES:
105+
if labels:
106+
dtypes = dpdata.system.LabeledSystem.DTYPES
107+
else:
108+
dtypes = dpdata.system.System.DTYPES
109+
for dtype in dtypes:
106110
if dtype.name in (
107111
"atom_numbs",
108112
"atom_names",
@@ -210,8 +214,13 @@ def dump(
210214
"virials": {"fn": "virial", "shape": (nframes, 9), "dump": True},
211215
}
212216

217+
labels = "energies" in data
218+
if labels:
219+
dtypes = dpdata.system.LabeledSystem.DTYPES
220+
else:
221+
dtypes = dpdata.system.System.DTYPES
213222
# allow custom dtypes
214-
for dtype in dpdata.system.LabeledSystem.DTYPES:
223+
for dtype in dtypes:
215224
if dtype.name in (
216225
"atom_numbs",
217226
"atom_names",

dpdata/deepmd/raw.py

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -64,40 +64,41 @@ def to_system_data(folder, type_map=None, labels=True):
6464
data["nopbc"] = True
6565
# allow custom dtypes
6666
if labels:
67-
for dtype in dpdata.system.LabeledSystem.DTYPES:
68-
if dtype.name in (
69-
"atom_numbs",
70-
"atom_names",
71-
"atom_types",
72-
"orig",
73-
"cells",
74-
"coords",
75-
"real_atom_types",
76-
"real_atom_names",
77-
"nopbc",
78-
"energies",
79-
"forces",
80-
"virials",
81-
):
82-
# skip as these data contains specific rules
83-
continue
84-
if not (
85-
len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES
86-
):
87-
warnings.warn(
88-
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format."
89-
)
90-
continue
91-
natoms = data["coords"].shape[1]
92-
shape = [
93-
natoms if xx == dpdata.system.Axis.NATOMS else xx
94-
for xx in dtype.shape[1:]
95-
]
96-
if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")):
97-
data[dtype.name] = np.reshape(
98-
np.loadtxt(os.path.join(folder, f"{dtype.name}.raw")),
99-
[nframes, *shape],
100-
)
67+
dtypes = dpdata.system.LabeledSystem.DTYPES
68+
else:
69+
dtypes = dpdata.system.System.DTYPES
70+
for dtype in dtypes:
71+
if dtype.name in (
72+
"atom_numbs",
73+
"atom_names",
74+
"atom_types",
75+
"orig",
76+
"cells",
77+
"coords",
78+
"real_atom_types",
79+
"real_atom_names",
80+
"nopbc",
81+
"energies",
82+
"forces",
83+
"virials",
84+
):
85+
# skip as these data contains specific rules
86+
continue
87+
if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES):
88+
warnings.warn(
89+
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format."
90+
)
91+
continue
92+
natoms = data["coords"].shape[1]
93+
shape = [
94+
natoms if xx == dpdata.system.Axis.NATOMS else xx
95+
for xx in dtype.shape[1:]
96+
]
97+
if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")):
98+
data[dtype.name] = np.reshape(
99+
np.loadtxt(os.path.join(folder, f"{dtype.name}.raw")),
100+
[nframes, *shape],
101+
)
101102
return data
102103
else:
103104
raise RuntimeError("not dir " + folder)
@@ -144,7 +145,12 @@ def dump(folder, data):
144145
with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
145146
pass
146147
# allow custom dtypes
147-
for dtype in dpdata.system.LabeledSystem.DTYPES:
148+
labels = "energies" in data
149+
if labels:
150+
dtypes = dpdata.system.LabeledSystem.DTYPES
151+
else:
152+
dtypes = dpdata.system.System.DTYPES
153+
for dtype in dtypes:
148154
if dtype.name in (
149155
"atom_numbs",
150156
"atom_names",

tests/plugin/dpdata_plugin_test/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), labeled=True
1111
)
1212

13+
register_data_type(
14+
DataType("foo", np.ndarray, (Axis.NFRAMES, 3, 3), required=False), labeled=False
15+
)
16+
1317
register_data_type(
1418
DataType("bar", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, -1), required=False),
1519
labeled=True,

tests/test_custom_data_type.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
from dpdata.data_type import Axis, DataType
1010

1111

12-
class TestDeepmdLoadDumpComp(unittest.TestCase):
12+
class DeepmdLoadDumpCompTest:
1313
def setUp(self):
14-
self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
15-
self.foo = np.ones((len(self.system), 2, 4))
14+
self.system = self.cls(
15+
data=dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar").data
16+
)
17+
self.foo = np.ones((len(self.system), *self.shape))
1618
self.system.data["foo"] = self.foo
1719
self.system.check_data()
1820

@@ -23,7 +25,7 @@ def test_to_deepmd_raw(self):
2325

2426
def test_from_deepmd_raw(self):
2527
self.system.to_deepmd_raw("data_foo")
26-
x = dpdata.LabeledSystem("data_foo", fmt="deepmd/raw")
28+
x = self.cls("data_foo", fmt="deepmd/raw")
2729
np.testing.assert_allclose(x.data["foo"], self.foo)
2830

2931
def test_to_deepmd_npy(self):
@@ -33,7 +35,7 @@ def test_to_deepmd_npy(self):
3335

3436
def test_from_deepmd_npy(self):
3537
self.system.to_deepmd_npy("data_foo")
36-
x = dpdata.LabeledSystem("data_foo", fmt="deepmd/npy")
38+
x = self.cls("data_foo", fmt="deepmd/npy")
3739
np.testing.assert_allclose(x.data["foo"], self.foo)
3840

3941
def test_to_deepmd_hdf5(self):
@@ -44,18 +46,34 @@ def test_to_deepmd_hdf5(self):
4446

4547
def test_from_deepmd_hdf5(self):
4648
self.system.to_deepmd_hdf5("data_foo.h5")
47-
x = dpdata.LabeledSystem("data_foo.h5", fmt="deepmd/hdf5")
49+
x = self.cls("data_foo.h5", fmt="deepmd/hdf5")
4850
np.testing.assert_allclose(x.data["foo"], self.foo)
4951

5052
def test_duplicated_data_type(self):
51-
dt = DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False)
52-
n_dtypes_old = len(dpdata.LabeledSystem.DTYPES)
53+
dt = DataType("foo", np.ndarray, (Axis.NFRAMES, *self.shape), required=False)
54+
n_dtypes_old = len(self.cls.DTYPES)
5355
with self.assertWarns(UserWarning):
54-
dpdata.LabeledSystem.register_data_type(dt)
55-
n_dtypes_new = len(dpdata.LabeledSystem.DTYPES)
56+
self.cls.register_data_type(dt)
57+
n_dtypes_new = len(self.cls.DTYPES)
5658
self.assertEqual(n_dtypes_old, n_dtypes_new)
5759

5860

61+
class TestDeepmdLoadDumpCompUnlabeled(unittest.TestCase, DeepmdLoadDumpCompTest):
62+
cls = dpdata.System
63+
shape = (3, 3)
64+
65+
def setUp(self):
66+
DeepmdLoadDumpCompTest.setUp(self)
67+
68+
69+
class TestDeepmdLoadDumpCompLabeled(unittest.TestCase, DeepmdLoadDumpCompTest):
70+
cls = dpdata.LabeledSystem
71+
shape = (2, 4)
72+
73+
def setUp(self):
74+
DeepmdLoadDumpCompTest.setUp(self)
75+
76+
5977
class TestDeepmdLoadDumpCompAny(unittest.TestCase):
6078
def setUp(self):
6179
self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")

0 commit comments

Comments
 (0)