Skip to content

Commit 5aade0d

Browse files
caic99, pre-commit-ci[bot], and iProzd
authored
Handle custom data types in mixed systems (#855)
## Summary

- Generalize the mixed-system loader to split all registered data types.
- Test that fparam and aparam round-trip through the deepmd mixed format.

## Testing

- `cd tests && pytest test_deepmd_mixed.py::TestMixedSystemWithFparamAparam -q`

------

https://chatgpt.com/codex/tasks/task_b_688b28aaa2d48332bd3a9bc79849b7a0

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Enhanced support for custom registered data types in system data handling, allowing additional data arrays beyond standard types to be processed and saved.
  * Added comprehensive tests for new custom data types, verifying their correct saving, loading, and integrity within mixed system files.
* **Tests**
  * Introduced new test cases to ensure custom data types are properly managed and validated in mixed system workflows.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Duo <[email protected]>
1 parent 597b8f6 commit 5aade0d

File tree

2 files changed

+175
-15
lines changed

2 files changed

+175
-15
lines changed

dpdata/deepmd/mixed.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import numpy as np
66

7+
import dpdata
8+
79
from .comp import dump as comp_dump
810
from .comp import to_system_data as comp_to_system_data
911

@@ -27,10 +29,32 @@ def to_system_data(folder, type_map=None, labels=True):
2729
all_real_atom_types_concat = index_map[all_real_atom_types_concat]
2830
all_cells_concat = data["cells"]
2931
all_coords_concat = data["coords"]
32+
33+
# handle custom registered data types
3034
if labels:
31-
all_eners_concat = data.get("energies")
32-
all_forces_concat = data.get("forces")
33-
all_virs_concat = data.get("virials")
35+
dtypes = dpdata.system.LabeledSystem.DTYPES
36+
else:
37+
dtypes = dpdata.system.System.DTYPES
38+
reserved = {
39+
"atom_numbs",
40+
"atom_names",
41+
"atom_types",
42+
"real_atom_names",
43+
"real_atom_types",
44+
"cells",
45+
"coords",
46+
"orig",
47+
"nopbc",
48+
}
49+
extra_data = {}
50+
for dtype in dtypes:
51+
name = dtype.name
52+
if name in reserved:
53+
continue
54+
if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES):
55+
continue
56+
if name in data:
57+
extra_data[name] = data.pop(name)
3458

3559
data_list = []
3660
while True:
@@ -56,16 +80,12 @@ def to_system_data(folder, type_map=None, labels=True):
5680
all_cells_concat = all_cells_concat[rest_idx]
5781
temp_data["coords"] = all_coords_concat[temp_idx]
5882
all_coords_concat = all_coords_concat[rest_idx]
59-
if labels:
60-
if all_eners_concat is not None and all_eners_concat.size > 0:
61-
temp_data["energies"] = all_eners_concat[temp_idx]
62-
all_eners_concat = all_eners_concat[rest_idx]
63-
if all_forces_concat is not None and all_forces_concat.size > 0:
64-
temp_data["forces"] = all_forces_concat[temp_idx]
65-
all_forces_concat = all_forces_concat[rest_idx]
66-
if all_virs_concat is not None and all_virs_concat.size > 0:
67-
temp_data["virials"] = all_virs_concat[temp_idx]
68-
all_virs_concat = all_virs_concat[rest_idx]
83+
84+
for name in extra_data:
85+
all_dtype_concat = extra_data[name]
86+
temp_data[name] = all_dtype_concat[temp_idx]
87+
extra_data[name] = all_dtype_concat[rest_idx]
88+
6989
data_list.append(temp_data)
7090
return data_list
7191

tests/test_deepmd_mixed.py

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
)
1616
from context import dpdata
1717

18+
from dpdata.data_type import (
19+
Axis,
20+
DataType,
21+
)
22+
1823

1924
class TestMixedMultiSystemsDumpLoad(
2025
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
@@ -455,5 +460,140 @@ def tearDown(self):
455460
shutil.rmtree("tmp.deepmd.mixed.single")
456461

457462

458-
if __name__ == "__main__":
459-
unittest.main()
463+
class TestMixedSystemWithFparamAparam(
    unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
):
    """Round-trip custom ``fparam``/``aparam`` arrays through deepmd/npy/mixed.

    Three labeled systems carrying the two extra registered data types are
    dumped with ``to_deepmd_npy_mixed`` and loaded back with
    ``from_deepmd_npy_mixed``; the tests then compare the original
    ``MultiSystems`` with the reloaded one.
    """

    def setUp(self):
        """Register the extra data types, build the systems, dump and reload."""
        # Comparison precisions; presumably read by the comparison mixins
        # this class inherits from (they are defined outside this file view).
        self.places = 6
        self.e_places = 6
        self.f_places = 6
        self.v_places = 6

        new_datatypes = [
            DataType(
                "fparam",
                np.ndarray,
                shape=(Axis.NFRAMES, 2),
                required=False,
            ),
            DataType(
                "aparam",
                np.ndarray,
                shape=(Axis.NFRAMES, Axis.NATOMS, 3),
                required=False,
            ),
        ]

        # NOTE(review): registration is repeated for every test method and is
        # never undone, so it is visible to other test classes - confirm that
        # this is acceptable.
        for datatype in new_datatypes:
            dpdata.System.register_data_type(datatype)
            dpdata.LabeledSystem.register_data_type(datatype)

        # Fixed seed so that a failing round trip can be reproduced exactly.
        rng = np.random.default_rng(0)

        # C1H4
        system_1 = dpdata.LabeledSystem(
            "gaussian/methane.gaussianlog", fmt="gaussian/log"
        )

        # C1H3
        system_2 = dpdata.LabeledSystem(
            "gaussian/methane_sub.gaussianlog", fmt="gaussian/log"
        )

        system_1_with_params = self._with_random_params(system_1.data.copy(), rng)
        system_2_with_params = self._with_random_params(system_2.data.copy(), rng)

        # C1H1A1B2 with params: methane data relabelled with two extra
        # atom names so that the three systems have different compositions.
        tmp_data_3 = system_1.data.copy()
        tmp_data_3["atom_numbs"] = [1, 1, 1, 2]
        tmp_data_3["atom_names"] = ["C", "H", "A", "B"]
        tmp_data_3["atom_types"] = np.array([0, 1, 2, 3, 3])
        system_3_with_params = self._with_random_params(tmp_data_3, rng)

        self.ms = dpdata.MultiSystems(
            system_1_with_params, system_2_with_params, system_3_with_params
        )

        # unittest skips tearDown when setUp raises, so register the removal
        # up front; otherwise a failed assertion below would leak the dump.
        self.addCleanup(
            shutil.rmtree, "tmp.deepmd.fparam.aparam", ignore_errors=True
        )
        self.ms.to_deepmd_npy_mixed("tmp.deepmd.fparam.aparam")
        self.place_holder_ms = dpdata.MultiSystems()
        self.place_holder_ms.from_deepmd_npy(
            "tmp.deepmd.fparam.aparam", fmt="deepmd/npy"
        )
        self.systems = dpdata.MultiSystems()
        self.systems.from_deepmd_npy_mixed(
            "tmp.deepmd.fparam.aparam", fmt="deepmd/npy/mixed"
        )

        self.ms_1 = self.ms
        self.ms_2 = self.systems

        mixed_sets = glob("tmp.deepmd.fparam.aparam/*/set.*")
        # Without this the loop below would pass vacuously on an empty dump.
        self.assertTrue(mixed_sets, "no set.* directories were written")
        for i in mixed_sets:
            self.assertTrue(os.path.exists(os.path.join(i, "real_atom_types.npy")))

        self.system_names = ["C1H4A0B0", "C1H3A0B0", "C1H1A1B2"]
        self.system_sizes = {"C1H4A0B0": 1, "C1H3A0B0": 1, "C1H1A1B2": 1}
        self.atom_names = ["C", "H", "A", "B"]

    @staticmethod
    def _with_random_params(data, rng):
        """Return a ``LabeledSystem`` built from *data* plus fparam/aparam.

        *data* is modified in place: ``fparam`` gets shape ``(nframes, 2)``
        and ``aparam`` gets shape ``(nframes, natoms, 3)``, both drawn from
        *rng*.
        """
        nframes = data["coords"].shape[0]
        natoms = len(data["atom_types"])
        data["fparam"] = rng.random((nframes, 2))
        data["aparam"] = rng.random((nframes, natoms, 3))
        return dpdata.LabeledSystem(data=data)

    def _assert_param_round_trip(self, key):
        """Assert that *key* is present and equal in every system on both sides."""
        for formula in self.system_names:
            # Fail loudly when a system is missing instead of skipping it.
            self.assertIn(formula, self.ms.systems)
            self.assertIn(formula, self.systems.systems)

            original = self.ms[formula].data
            reloaded = self.systems[formula].data
            self.assertIn(key, original)
            self.assertIn(key, reloaded)
            np.testing.assert_almost_equal(
                original[key],
                reloaded[key],
                decimal=self.places,
            )

    def tearDown(self):
        """Remove the dump directory written by ``setUp``."""
        if os.path.exists("tmp.deepmd.fparam.aparam"):
            shutil.rmtree("tmp.deepmd.fparam.aparam")

    def test_len(self):
        """Both the original and the reloaded collection hold three systems."""
        self.assertEqual(len(self.ms), 3)
        self.assertEqual(len(self.systems), 3)

    def test_get_nframes(self):
        """Both collections report three frames in total."""
        self.assertEqual(self.ms.get_nframes(), 3)
        self.assertEqual(self.systems.get_nframes(), 3)

    def test_str(self):
        """The string form reports three systems and three frames."""
        self.assertEqual(str(self.ms), "MultiSystems (3 systems containing 3 frames)")
        self.assertEqual(
            str(self.systems), "MultiSystems (3 systems containing 3 frames)"
        )

    def test_fparam_exists(self):
        """``fparam`` survives the mixed-format round trip for every system."""
        self._assert_param_round_trip("fparam")

    def test_aparam_exists(self):
        """``aparam`` survives the mixed-format round trip for every system."""
        self._assert_param_round_trip("aparam")

0 commit comments

Comments
 (0)