Skip to content

Commit 1ccc57d

Browse files
authored
test(common): add regression for atom type remap (#5050)
## Summary - add a regression test that reproduces the original atom-type remapping failure when the training system contains fewer atom types than the provided type_map - ensure the fix from commit 4ec0437 remains covered going forward <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Added unit tests for DeepmdData that validate atom-type remapping, handling of unused types, and correct loading/formatting of type arrays. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 6fb0703 commit 1ccc57d

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import tempfile
3+
import unittest
4+
from pathlib import (
5+
Path,
6+
)
7+
8+
import numpy as np
9+
10+
from deepmd.utils.data import (
11+
DeepmdData,
12+
)
13+
14+
15+
class TestDeepmdDataTypeMap(unittest.TestCase):
16+
def setUp(self) -> None:
17+
self.tmpdir = tempfile.TemporaryDirectory()
18+
self.root = Path(self.tmpdir.name)
19+
self.set_dir = self.root / "set.000"
20+
self.set_dir.mkdir()
21+
22+
# minimal required dataset
23+
atom_types = np.array([0, 1, 0, 1], dtype=np.int32)
24+
np.savetxt(self.root / "type.raw", atom_types, fmt="%d")
25+
np.savetxt(
26+
self.root / "type_map.raw",
27+
np.array(["O", "H", "Si"], dtype=object),
28+
fmt="%s",
29+
)
30+
31+
coord = np.zeros((1, atom_types.size * 3), dtype=np.float32)
32+
box = np.eye(3, dtype=np.float32).reshape(1, 9)
33+
np.save(self.set_dir / "coord.npy", coord)
34+
np.save(self.set_dir / "box.npy", box)
35+
36+
def tearDown(self) -> None:
37+
self.tmpdir.cleanup()
38+
39+
def test_remap_with_unused_types(self) -> None:
40+
data = DeepmdData(str(self.root), type_map=["H", "O", "Si"])
41+
42+
expected_atom_types = np.array([1, 0, 1, 0], dtype=np.int32)
43+
np.testing.assert_array_equal(data.atom_type, expected_atom_types)
44+
self.assertEqual(data.type_map, ["H", "O", "Si"])
45+
46+
loaded = data._load_set(self.set_dir)
47+
expected_sorted = expected_atom_types[data.idx_map]
48+
np.testing.assert_array_equal(loaded["type"], np.tile(expected_sorted, (1, 1)))

0 commit comments

Comments
 (0)