@@ -1213,22 +1213,42 @@ def test_propertied_structure(self):
1213
1213
assert dct == struct .as_dict ()
1214
1214
1215
1215
def test_perturb (self ):
1216
- dist = 0.1
1217
- pre_perturbation_sites = self .struct .copy ()
1218
- returned = self .struct .perturb (distance = dist )
1219
- assert returned is self .struct
1220
- post_perturbation_sites = self .struct .sites
1221
-
1222
- for idx , site in enumerate (pre_perturbation_sites ):
1223
- assert site .distance (post_perturbation_sites [idx ]) == approx (dist ), "Bad perturbation distance"
1224
-
1225
- structure2 = pre_perturbation_sites .copy ()
1226
- structure2 .perturb (distance = dist , min_distance = 0 )
1227
- post_perturbation_sites2 = structure2 .sites
1228
-
1229
- for idx , site in enumerate (pre_perturbation_sites ):
1230
- assert site .distance (post_perturbation_sites2 [idx ]) <= dist
1231
- assert site .distance (post_perturbation_sites2 [idx ]) >= 0
1216
+ struct = self .get_structure ("Li2O" )
1217
+ struct_orig = struct .copy ()
1218
+ struct .perturb (0.1 )
1219
+ # Ensure all sites were perturbed by a distance of at most 0.1 Angstroms
1220
+ for site , site_orig in zip (struct , struct_orig , strict = True ):
1221
+ cart_dist = site .distance (site_orig )
1222
+ # allow 1e-6 to account for numerical precision
1223
+ assert cart_dist <= 0.1 + 1e-6 , f"Distance { cart_dist } > 0.1"
1224
+
1225
+ # Test that same seed gives same perturbation
1226
+ s1 = self .get_structure ("Li2O" )
1227
+ s2 = self .get_structure ("Li2O" )
1228
+ s1 .perturb (0.1 , seed = 42 )
1229
+ s2 .perturb (0.1 , seed = 42 )
1230
+ for site1 , site2 in zip (s1 , s2 , strict = True ):
1231
+ assert site1 .distance (site2 ) < 1e-7 # should be exactly equal up to numerical precision
1232
+
1233
+ # Test that different seeds give different perturbations
1234
+ s3 = self .get_structure ("Li2O" )
1235
+ s3 .perturb (0.1 , seed = 100 )
1236
+ any_different = False
1237
+ for site1 , site3 in zip (s1 , s3 , strict = True ):
1238
+ if site1 .distance (site3 ) > 1e-7 :
1239
+ any_different = True
1240
+ break
1241
+ assert any_different , "Different seeds should give different perturbations"
1242
+
1243
+ # Test min_distance
1244
+ s4 = self .get_structure ("Li2O" )
1245
+ s4 .perturb (0.1 , min_distance = 0.05 , seed = 42 )
1246
+ any_different = False
1247
+ for site1 , site4 in zip (s1 , s4 , strict = True ):
1248
+ if site1 .distance (site4 ) > 1e-7 :
1249
+ any_different = True
1250
+ break
1251
+ assert any_different , "Using min_distance should give different perturbations"
1232
1252
1233
1253
def test_add_oxidation_state_by_element (self ):
1234
1254
oxidation_states = {"Si" : - 4 }
0 commit comments