Skip to content

Commit fb942bb

Browse files
Feat: Support specifying proportion of atoms to be perturbed in System (#716)
See title. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new parameter for controlled atom perturbation in the perturb function, enhancing flexibility. - **Bug Fixes** - Improved logic for selecting atoms to perturb, ensuring only a specified proportion is affected. - **Tests** - Added a new test class to validate the perturbation functionality for atomic systems, increasing test coverage and reliability. - Introduced a structured representation of a Silicon Carbide crystal for validation in tests. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1de5ace commit fb942bb

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

dpdata/system.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,7 @@ def perturb(
849849
cell_pert_fraction: float,
850850
atom_pert_distance: float,
851851
atom_pert_style: str = "normal",
852+
atom_pert_prob: float = 1.0,
852853
):
853854
"""Perturb each frame in the system randomly.
854855
The cell will be deformed randomly, and atoms will be displaced by a random distance in random direction.
@@ -877,6 +878,8 @@ def perturb(
877878
These points are treated as vector used by atoms to move.
878879
Obviously, the max length of the distance atoms move is `atom_pert_distance`.
879880
- `'const'`: The distance atoms move will be a constant `atom_pert_distance`.
881+
atom_pert_prob : float
882+
Determine the proportion of the total number of atoms in a frame that are perturbed.
880883
881884
Returns
882885
-------
@@ -900,7 +903,15 @@ def perturb(
900903
tmp_system.data["coords"][0] = np.matmul(
901904
tmp_system.data["coords"][0], cell_perturb_matrix
902905
)
903-
for kk in range(len(tmp_system.data["coords"][0])):
906+
pert_natoms = int(atom_pert_prob * len(tmp_system.data["coords"][0]))
907+
pert_atom_id = sorted(
908+
np.random.choice(
909+
range(len(tmp_system.data["coords"][0])),
910+
pert_natoms,
911+
replace=False,
912+
).tolist()
913+
)
914+
for kk in pert_atom_id:
904915
atom_perturb_vector = get_atom_perturb_vector(
905916
atom_pert_distance, atom_pert_style
906917
)

tests/poscars/POSCAR.SiC.partpert

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
C4 Si4
2+
1.0
3+
4.0354487481064565e+00 1.1027270790560616e-17 2.5642993008475204e-17
4+
2.0693526054669642e-01 4.1066892997402196e+00 -8.6715682899078028e-18
5+
4.2891472979598610e-01 5.5796885749827474e-01 4.1100061517204542e+00
6+
C Si
7+
4 4
8+
Cartesian
9+
0.03122504 0.15559669 2.1913045
10+
1.93908836 -0.08678864 0.06748919
11+
0.13114716 2.15827511 0.06333341
12+
2.36161952 1.42824405 2.58837618
13+
-0.03895165 0.12197669 0.05496244
14+
1.79528462 2.48830207 -0.55733221
15+
2.11363589 0.09280028 2.0301803
16+
0.19221505 2.16245144 2.07930701

tests/test_perturb.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@ class NormalGenerator:
1212
def __init__(self):
1313
self.randn_generator = self.get_randn_generator()
1414
self.rand_generator = self.get_rand_generator()
15+
self.choice_generator = self.get_choice_generator()
1516

1617
def randn(self, number):
1718
return next(self.randn_generator)
1819

1920
def rand(self, number):
2021
return next(self.rand_generator)
2122

23+
def choice(self, total_natoms, pert_natoms, replace):
24+
return next(self.choice_generator)[:pert_natoms]
25+
2226
@staticmethod
2327
def get_randn_generator():
2428
data = np.asarray(
@@ -44,18 +48,26 @@ def get_rand_generator():
4448
[0.23182233, 0.87106847, 0.68728511, 0.94180274, 0.92860453, 0.69191187]
4549
)
4650

51+
@staticmethod
52+
def get_choice_generator():
53+
yield np.asarray([5, 3, 7, 6, 2, 1, 4, 0])
54+
4755

4856
class UniformGenerator:
4957
def __init__(self):
5058
self.randn_generator = self.get_randn_generator()
5159
self.rand_generator = self.get_rand_generator()
60+
self.choice_generator = self.get_choice_generator()
5261

5362
def randn(self, number):
5463
return next(self.randn_generator)
5564

5665
def rand(self, number):
5766
return next(self.rand_generator)
5867

68+
def choice(self, total_natoms, pert_natoms, replace):
69+
return next(self.choice_generator)
70+
5971
@staticmethod
6072
def get_randn_generator():
6173
data = [
@@ -97,18 +109,26 @@ def get_rand_generator():
97109
yield np.asarray(data[count])
98110
count += 1
99111

112+
@staticmethod
113+
def get_choice_generator():
114+
yield np.asarray([5, 3, 7, 6, 2, 1, 4, 0])
115+
100116

101117
class ConstGenerator:
102118
def __init__(self):
103119
self.randn_generator = self.get_randn_generator()
104120
self.rand_generator = self.get_rand_generator()
121+
self.choice_generator = self.get_choice_generator()
105122

106123
def randn(self, number):
107124
return next(self.randn_generator)
108125

109126
def rand(self, number):
110127
return next(self.rand_generator)
111128

129+
def choice(self, total_natoms, pert_natoms, replace):
130+
return next(self.choice_generator)
131+
112132
@staticmethod
113133
def get_randn_generator():
114134
data = np.asarray(
@@ -135,13 +155,18 @@ def get_rand_generator():
135155
[0.01525907, 0.68387374, 0.39768541, 0.55596047, 0.26557088, 0.60883073]
136156
)
137157

158+
@staticmethod
159+
def get_choice_generator():
160+
yield np.asarray([5, 3, 7, 6, 2, 1, 4, 0])
161+
138162

139163
# %%
140164
class TestPerturbNormal(unittest.TestCase, CompSys, IsPBC):
141165
@patch("numpy.random")
142166
def setUp(self, random_mock):
143167
random_mock.rand = NormalGenerator().rand
144168
random_mock.randn = NormalGenerator().randn
169+
random_mock.choice = NormalGenerator().choice
145170
system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar")
146171
self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "normal")
147172
self.system_2 = dpdata.System("poscars/POSCAR.SiC.normal", fmt="vasp/poscar")
@@ -153,6 +178,7 @@ class TestPerturbUniform(unittest.TestCase, CompSys, IsPBC):
153178
def setUp(self, random_mock):
154179
random_mock.rand = UniformGenerator().rand
155180
random_mock.randn = UniformGenerator().randn
181+
random_mock.choice = UniformGenerator().choice
156182
system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar")
157183
self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "uniform")
158184
self.system_2 = dpdata.System("poscars/POSCAR.SiC.uniform", fmt="vasp/poscar")
@@ -164,11 +190,24 @@ class TestPerturbConst(unittest.TestCase, CompSys, IsPBC):
164190
def setUp(self, random_mock):
165191
random_mock.rand = ConstGenerator().rand
166192
random_mock.randn = ConstGenerator().randn
193+
random_mock.choice = ConstGenerator().choice
167194
system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar")
168195
self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "const")
169196
self.system_2 = dpdata.System("poscars/POSCAR.SiC.const", fmt="vasp/poscar")
170197
self.places = 6
171198

172199

200+
class TestPerturbPartAtoms(unittest.TestCase, CompSys, IsPBC):
201+
@patch("numpy.random")
202+
def setUp(self, random_mock):
203+
random_mock.rand = NormalGenerator().rand
204+
random_mock.randn = NormalGenerator().randn
205+
random_mock.choice = NormalGenerator().choice
206+
system_1_origin = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar")
207+
self.system_1 = system_1_origin.perturb(1, 0.05, 0.6, "normal", 0.25)
208+
self.system_2 = dpdata.System("poscars/POSCAR.SiC.partpert", fmt="vasp/poscar")
209+
self.places = 6
210+
211+
173212
if __name__ == "__main__":
174213
unittest.main()

0 commit comments

Comments
 (0)