Skip to content

Commit f8a1b6b

Browse files
wanghan-iapcmHan Wang
andauthored
test: support comparison between two multi systems (#705)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced functions and classes to enhance testing capabilities for multi-system comparisons. - Added validation classes for periodic boundary conditions across multi-system objects. - **Bug Fixes** - Updated test classes to utilize new multi-system handling, improving clarity and functionality. - **Documentation** - Enhanced clarity in variable naming for better alignment with multi-system concepts. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <[email protected]>
1 parent 676517a commit f8a1b6b

File tree

2 files changed

+100
-10
lines changed

2 files changed

+100
-10
lines changed

tests/comp_sys.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,72 @@ def test_virial(self):
105105
)
106106

107107

108+
def _make_comp_ms_test_func(comp_sys_test_func):
109+
"""
110+
Dynamically generates a test function for multi-system comparisons.
111+
112+
Args:
113+
comp_sys_test_func (Callable): The original test function for single systems.
114+
115+
Returns
116+
-------
117+
Callable: A new test function that can handle comparisons between multi-systems.
118+
"""
119+
120+
def comp_ms_test_func(iobj):
121+
assert hasattr(iobj, "ms_1") and hasattr(
122+
iobj, "ms_2"
123+
), "Multi-system objects must be present"
124+
iobj.assertEqual(len(iobj.ms_1), len(iobj.ms_2))
125+
keys = [ii.formula for ii in iobj.ms_1]
126+
keys_2 = [ii.formula for ii in iobj.ms_2]
127+
assert sorted(keys) == sorted(
128+
keys_2
129+
), f"Keys of two MS are not equal: {keys} != {keys_2}"
130+
for kk in keys:
131+
iobj.system_1 = iobj.ms_1[kk]
132+
iobj.system_2 = iobj.ms_2[kk]
133+
comp_sys_test_func(iobj)
134+
del iobj.system_1
135+
del iobj.system_2
136+
137+
return comp_ms_test_func
138+
139+
140+
def _make_comp_ms_class(comp_class):
141+
"""
142+
Dynamically generates a test class for multi-system comparisons.
143+
144+
Args:
145+
comp_class (type): The original test class for single systems.
146+
147+
Returns
148+
-------
149+
type: A new test class that can handle comparisons between multi-systems.
150+
"""
151+
152+
class CompMS:
153+
pass
154+
155+
test_methods = [
156+
func
157+
for func in dir(comp_class)
158+
if callable(getattr(comp_class, func)) and func.startswith("test_")
159+
]
160+
161+
for func in test_methods:
162+
setattr(CompMS, func, _make_comp_ms_test_func(getattr(comp_class, func)))
163+
164+
return CompMS
165+
166+
167+
# MultiSystems comparison from single System comparison
168+
CompMultiSys = _make_comp_ms_class(CompSys)
169+
170+
# LabeledMultiSystems comparison from single LabeledSystem comparison
171+
CompLabeledMultiSys = _make_comp_ms_class(CompLabeledSys)
172+
173+
108174
class MultiSystems:
109175
def test_systems_name(self):
110176
self.assertEqual(set(self.systems.systems), set(self.system_names))
@@ -127,3 +193,21 @@ class IsNoPBC:
127193
def test_is_nopbc(self):
128194
self.assertTrue(self.system_1.nopbc)
129195
self.assertTrue(self.system_2.nopbc)
196+
197+
198+
class MSAllIsPBC:
199+
def test_is_pbc(self):
200+
assert hasattr(self, "ms_1") and hasattr(
201+
self, "ms_2"
202+
), "Multi-system objects must be present and iterable"
203+
self.assertTrue(all([not ss.nopbc for ss in self.ms_1]))
204+
self.assertTrue(all([not ss.nopbc for ss in self.ms_2]))
205+
206+
207+
class MSAllIsNoPBC:
208+
def test_is_nopbc(self):
209+
assert hasattr(self, "ms_1") and hasattr(
210+
self, "ms_2"
211+
), "Multi-system objects must be present and iterable"
212+
self.assertTrue(all([ss.nopbc for ss in self.ms_1]))
213+
self.assertTrue(all([ss.nopbc for ss in self.ms_2]))

tests/test_deepmd_mixed.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,18 @@
66
from glob import glob
77

88
import numpy as np
9-
from comp_sys import CompLabeledSys, IsNoPBC, MultiSystems
9+
from comp_sys import (
10+
CompLabeledMultiSys,
11+
CompLabeledSys,
12+
IsNoPBC,
13+
MSAllIsNoPBC,
14+
MultiSystems,
15+
)
1016
from context import dpdata
1117

1218

1319
class TestMixedMultiSystemsDumpLoad(
14-
unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC
20+
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
1521
):
1622
def setUp(self):
1723
self.places = 6
@@ -62,8 +68,8 @@ def setUp(self):
6268
self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy")
6369
self.systems = dpdata.MultiSystems()
6470
self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed")
65-
self.system_1 = self.ms["C1H4A0B0D0"]
66-
self.system_2 = self.systems["C1H4A0B0D0"]
71+
self.ms_1 = self.ms
72+
self.ms_2 = self.systems
6773
mixed_sets = glob("tmp.deepmd.mixed/*/set.*")
6874
self.assertEqual(len(mixed_sets), 2)
6975
for i in mixed_sets:
@@ -112,7 +118,7 @@ def test_str(self):
112118

113119

114120
class TestMixedMultiSystemsDumpLoadSetSize(
115-
unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC
121+
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
116122
):
117123
def setUp(self):
118124
self.places = 6
@@ -163,8 +169,8 @@ def setUp(self):
163169
self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy")
164170
self.systems = dpdata.MultiSystems()
165171
self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed")
166-
self.system_1 = self.ms["C1H4A0B0D0"]
167-
self.system_2 = self.systems["C1H4A0B0D0"]
172+
self.ms_1 = self.ms
173+
self.ms_2 = self.systems
168174
mixed_sets = glob("tmp.deepmd.mixed/*/set.*")
169175
self.assertEqual(len(mixed_sets), 5)
170176
for i in mixed_sets:
@@ -213,7 +219,7 @@ def test_str(self):
213219

214220

215221
class TestMixedMultiSystemsTypeChange(
216-
unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC
222+
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
217223
):
218224
def setUp(self):
219225
self.places = 6
@@ -265,8 +271,8 @@ def setUp(self):
265271
self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy")
266272
self.systems = dpdata.MultiSystems(type_map=["TOKEN"])
267273
self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed")
268-
self.system_1 = self.ms["TOKEN0C1H4A0B0D0"]
269-
self.system_2 = self.systems["TOKEN0C1H4A0B0D0"]
274+
self.ms_1 = self.ms
275+
self.ms_2 = self.systems
270276
mixed_sets = glob("tmp.deepmd.mixed/*/set.*")
271277
self.assertEqual(len(mixed_sets), 2)
272278
for i in mixed_sets:

0 commit comments

Comments
 (0)