Skip to content

Commit 6d082f1

Browse files
wanghan-iapcmHan Wang
andauthored
test: mixed data format: test if the index_map (when type_map is provided) works (#706)
consider the PR after #705 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced testing framework for multi-system comparisons, allowing for dynamic generation of test functions and classes. - Introduced new test classes to validate properties of multi-system objects regarding periodic boundary conditions. - Added a new test class for handling type mapping in labeled systems. - **Bug Fixes** - Updated existing test classes to improve clarity and consistency in naming conventions for multi-system variables. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <[email protected]>
1 parent f8a1b6b commit 6d082f1

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

tests/test_deepmd_mixed.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,119 @@ def test_str(self):
117117
)
118118

119119

120+
class TestMixedMultiSystemsDumpLoadTypeMap(
121+
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
122+
):
123+
def setUp(self):
124+
self.places = 6
125+
self.e_places = 6
126+
self.f_places = 6
127+
self.v_places = 6
128+
129+
# C1H4
130+
system_1 = dpdata.LabeledSystem(
131+
"gaussian/methane.gaussianlog", fmt="gaussian/log"
132+
)
133+
134+
# C1H3
135+
system_2 = dpdata.LabeledSystem(
136+
"gaussian/methane_sub.gaussianlog", fmt="gaussian/log"
137+
)
138+
139+
tmp_data = system_1.data.copy()
140+
tmp_data["atom_numbs"] = [1, 1, 1, 2]
141+
tmp_data["atom_names"] = ["C", "H", "A", "B"]
142+
tmp_data["atom_types"] = np.array([0, 1, 2, 3, 3])
143+
# C1H1A1B2
144+
system_1_modified_type_1 = dpdata.LabeledSystem(data=tmp_data)
145+
146+
tmp_data = system_1.data.copy()
147+
tmp_data["atom_numbs"] = [1, 1, 2, 1]
148+
tmp_data["atom_names"] = ["C", "H", "A", "B"]
149+
tmp_data["atom_types"] = np.array([0, 1, 2, 2, 3])
150+
# C1H1A2B1
151+
system_1_modified_type_2 = dpdata.LabeledSystem(data=tmp_data)
152+
153+
tmp_data = system_1.data.copy()
154+
tmp_data["atom_numbs"] = [1, 1, 1, 2]
155+
tmp_data["atom_names"] = ["C", "H", "A", "D"]
156+
tmp_data["atom_types"] = np.array([0, 1, 2, 3, 3])
157+
# C1H1A1C2
158+
system_1_modified_type_3 = dpdata.LabeledSystem(data=tmp_data)
159+
160+
self.ms = dpdata.MultiSystems(
161+
system_1,
162+
system_2,
163+
system_1_modified_type_1,
164+
system_1_modified_type_2,
165+
system_1_modified_type_3,
166+
)
167+
168+
self.ms.to_deepmd_npy_mixed("tmp.deepmd.mixed")
169+
self.place_holder_ms = dpdata.MultiSystems()
170+
self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy")
171+
172+
new_type_map = ["H", "C", "D", "A", "B"]
173+
self.systems = dpdata.MultiSystems()
174+
self.systems.from_deepmd_npy_mixed(
175+
"tmp.deepmd.mixed", fmt="deepmd/npy/mixed", type_map=new_type_map
176+
)
177+
for kk in [ii.formula for ii in self.ms]:
178+
# apply type_map to each system
179+
self.ms[kk].apply_type_map(new_type_map)
180+
# revise keys in dict according because the type_map is updated.
181+
tmp_ss = self.ms.systems.pop(kk)
182+
self.ms.systems[tmp_ss.formula] = tmp_ss
183+
184+
self.ms_1 = self.ms
185+
self.ms_2 = self.systems
186+
mixed_sets = glob("tmp.deepmd.mixed/*/set.*")
187+
self.assertEqual(len(mixed_sets), 2)
188+
for i in mixed_sets:
189+
self.assertEqual(
190+
os.path.exists(os.path.join(i, "real_atom_types.npy")), True
191+
)
192+
193+
self.system_names = [
194+
"H4C1D0A0B0",
195+
"H3C1D0A0B0",
196+
"H1C1D0A1B2",
197+
"H1C1D0A2B1",
198+
"H1C1D2A1B0",
199+
]
200+
self.system_sizes = {
201+
"H4C1D0A0B0": 1,
202+
"H3C1D0A0B0": 1,
203+
"H1C1D0A1B2": 1,
204+
"H1C1D0A2B1": 1,
205+
"H1C1D2A1B0": 1,
206+
}
207+
self.atom_names = ["H", "C", "D", "A", "B"]
208+
209+
def tearDown(self):
210+
if os.path.exists("tmp.deepmd.mixed"):
211+
shutil.rmtree("tmp.deepmd.mixed")
212+
213+
def test_len(self):
214+
self.assertEqual(len(self.ms), 5)
215+
self.assertEqual(len(self.place_holder_ms), 2)
216+
self.assertEqual(len(self.systems), 5)
217+
218+
def test_get_nframes(self):
219+
self.assertEqual(self.ms.get_nframes(), 5)
220+
self.assertEqual(self.place_holder_ms.get_nframes(), 5)
221+
self.assertEqual(self.systems.get_nframes(), 5)
222+
223+
def test_str(self):
224+
self.assertEqual(str(self.ms), "MultiSystems (5 systems containing 5 frames)")
225+
self.assertEqual(
226+
str(self.place_holder_ms), "MultiSystems (2 systems containing 5 frames)"
227+
)
228+
self.assertEqual(
229+
str(self.systems), "MultiSystems (5 systems containing 5 frames)"
230+
)
231+
232+
120233
class TestMixedMultiSystemsDumpLoadSetSize(
121234
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
122235
):

0 commit comments

Comments
 (0)