Skip to content

Commit bba3e9f

Browse files
wanghan-iapcmHan Wangpre-commit-ci[bot]
authored
fix: the replicate will fail if the atom types of system is not sorted (#667)
- add UT to detect the issue. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved the logic for replicating atom types in the system, ensuring more accurate replication behavior. - **Tests** - Added new tests to verify the replication of atom types, enhancing test coverage and reliability. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 199afc1 commit bba3e9f

File tree

3 files changed

+38
-4
lines changed

3 files changed

+38
-4
lines changed

dpdata/system.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,11 @@ def replicate(self, ncopy: list[int] | tuple[int, int, int]):
776776
tmp.data["atom_numbs"] = list(
777777
np.array(np.copy(data["atom_numbs"])) * np.prod(ncopy)
778778
)
779-
tmp.data["atom_types"] = np.sort(
780-
np.tile(np.copy(data["atom_types"]), np.prod(ncopy).item()), kind="stable"
779+
tmp.data["atom_types"] = np.tile(
780+
np.copy(data["atom_types"]), (int(np.prod(ncopy)),) + (1,)
781781
)
782+
tmp.data["atom_types"] = np.transpose(tmp.data["atom_types"]).reshape([-1])
783+
782784
tmp.data["cells"] = np.copy(data["cells"])
783785
for ii in range(3):
784786
tmp.data["cells"][:, ii, :] *= ncopy[ii]

tests/test_predict.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def setUp(self):
106106
)
107107
zero_driver = ZeroDriver()
108108
self.system_1 = ori_sys.predict(driver=zero_driver)
109-
self.system_2 = ori_sys.minimize(driver=zero_driver, minimizer="ase")
109+
self.system_2 = ori_sys.minimize(
110+
driver=zero_driver, minimizer="ase", max_steps=100
111+
)
110112
self.places = 6
111113
self.e_places = 6
112114
self.f_places = 6
@@ -123,7 +125,9 @@ def setUp(self):
123125
zero_driver = ZeroDriver()
124126
self.system_1 = list(multi_sys.predict(driver=zero_driver).systems.values())[0]
125127
self.system_2 = list(
126-
multi_sys.minimize(driver=zero_driver, minimizer="ase").systems.values()
128+
multi_sys.minimize(
129+
driver=zero_driver, minimizer="ase", max_steps=100
130+
).systems.values()
127131
)[0]
128132
self.places = 6
129133
self.e_places = 6

tests/test_replicate.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import unittest
44

5+
import numpy as np
56
from comp_sys import CompSys, IsPBC
67
from context import dpdata
78

@@ -36,5 +37,32 @@ def setUp(self):
3637
self.places = 6
3738

3839

40+
class TestReplicateTriclinicBox(unittest.TestCase, CompSys, IsPBC):
41+
def setUp(self):
42+
self.system_1 = dpdata.System()
43+
self.system_1.data["atom_names"] = ["foo", "bar"]
44+
self.system_1.data["atom_types"] = np.array([1, 0], dtype=int)
45+
self.system_1.data["atom_numbs"] = [1, 1]
46+
self.system_1.data["cells"] = np.array(
47+
[10, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float
48+
).reshape(1, 3, 3)
49+
self.system_1.data["coords"] = np.array(
50+
[0, 0, 0, 0, 0, 1], dtype=float
51+
).reshape(1, 2, 3)
52+
self.system_1 = self.system_1.replicate([2, 1, 1])
53+
54+
self.system_2 = dpdata.System()
55+
self.system_2.data["atom_names"] = ["foo", "bar"]
56+
self.system_2.data["atom_types"] = np.array([1, 1, 0, 0], dtype=int)
57+
self.system_2.data["atom_numbs"] = [2, 2]
58+
self.system_2.data["cells"] = np.array(
59+
[20, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float
60+
).reshape(1, 3, 3)
61+
self.system_2.data["coords"] = np.array(
62+
[0, 0, 0, 10, 0, 0, 0, 0, 1, 10, 0, 1], dtype=float
63+
).reshape(1, 4, 3)
64+
self.places = 6
65+
66+
3967
if __name__ == "__main__":
4068
unittest.main()

0 commit comments

Comments
 (0)