Skip to content

Commit 90a1ef7

Browse files
committed
Merge branch 'master' of github.com:materialsproject/pymatgen
2 parents 9656ff9 + f82ce1f commit 90a1ef7

File tree

2 files changed

+41
-21
lines changed

2 files changed

+41
-21
lines changed

src/pymatgen/core/structure.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4635,24 +4635,24 @@ def rotate_sites(
46354635

46364636
return self
46374637

4638-
def perturb(self, distance: float, min_distance: float | None = None) -> Self:
4638+
def perturb(self, distance: float, min_distance: float | None = None, seed: int = 0) -> Self:
46394639
"""Perform a random perturbation of the sites in a structure to break
46404640
symmetries. Modifies the structure in place.
46414641
46424642
Args:
46434643
distance (float): Distance in angstroms by which to perturb each site.
46444644
min_distance (None, int, or float): if None, all displacements will
4645-
be equal amplitude. If int or float, perturb each site a
4646-
distance drawn from the uniform distribution between
4647-
'min_distance' and 'distance'.
4645+
be equal amplitude. If int or float, perturb each site a distance drawn
4646+
from the uniform distribution between 'min_distance' and 'distance'.
4647+
seed (int): Seed for the random number generator. Defaults to 0.
46484648
46494649
Returns:
46504650
Structure: self with perturbed sites.
46514651
"""
46524652

46534653
def get_rand_vec():
46544654
# Deal with zero vectors
4655-
rng = np.random.default_rng()
4655+
rng = np.random.default_rng(seed=seed)
46564656
vector = rng.standard_normal(3)
46574657
vnorm = np.linalg.norm(vector)
46584658
dist = distance

tests/core/test_structure.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,22 +1213,42 @@ def test_propertied_structure(self):
12131213
assert dct == struct.as_dict()
12141214

12151215
def test_perturb(self):
1216-
dist = 0.1
1217-
pre_perturbation_sites = self.struct.copy()
1218-
returned = self.struct.perturb(distance=dist)
1219-
assert returned is self.struct
1220-
post_perturbation_sites = self.struct.sites
1221-
1222-
for idx, site in enumerate(pre_perturbation_sites):
1223-
assert site.distance(post_perturbation_sites[idx]) == approx(dist), "Bad perturbation distance"
1224-
1225-
structure2 = pre_perturbation_sites.copy()
1226-
structure2.perturb(distance=dist, min_distance=0)
1227-
post_perturbation_sites2 = structure2.sites
1228-
1229-
for idx, site in enumerate(pre_perturbation_sites):
1230-
assert site.distance(post_perturbation_sites2[idx]) <= dist
1231-
assert site.distance(post_perturbation_sites2[idx]) >= 0
1216+
struct = self.get_structure("Li2O")
1217+
struct_orig = struct.copy()
1218+
struct.perturb(0.1)
1219+
# Ensure all sites were perturbed by a distance of at most 0.1 Angstroms
1220+
for site, site_orig in zip(struct, struct_orig, strict=True):
1221+
cart_dist = site.distance(site_orig)
1222+
# allow 1e-6 to account for numerical precision
1223+
assert cart_dist <= 0.1 + 1e-6, f"Distance {cart_dist} > 0.1"
1224+
1225+
# Test that same seed gives same perturbation
1226+
s1 = self.get_structure("Li2O")
1227+
s2 = self.get_structure("Li2O")
1228+
s1.perturb(0.1, seed=42)
1229+
s2.perturb(0.1, seed=42)
1230+
for site1, site2 in zip(s1, s2, strict=True):
1231+
assert site1.distance(site2) < 1e-7 # should be exactly equal up to numerical precision
1232+
1233+
# Test that different seeds give different perturbations
1234+
s3 = self.get_structure("Li2O")
1235+
s3.perturb(0.1, seed=100)
1236+
any_different = False
1237+
for site1, site3 in zip(s1, s3, strict=True):
1238+
if site1.distance(site3) > 1e-7:
1239+
any_different = True
1240+
break
1241+
assert any_different, "Different seeds should give different perturbations"
1242+
1243+
# Test min_distance
1244+
s4 = self.get_structure("Li2O")
1245+
s4.perturb(0.1, min_distance=0.05, seed=42)
1246+
any_different = False
1247+
for site1, site4 in zip(s1, s4, strict=True):
1248+
if site1.distance(site4) > 1e-7:
1249+
any_different = True
1250+
break
1251+
assert any_different, "Using min_distance should give different perturbations"
12321252

12331253
def test_add_oxidation_state_by_element(self):
12341254
oxidation_states = {"Si": -4}

0 commit comments

Comments
 (0)