Skip to content

Commit 9d3d82c

Browse files
[Feature] Added Pure Random Algo to OrderDisorderedStructureTransformation (#4236)
* occ_tol * random transformation * pre-commit auto-fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 300a33e commit 9d3d82c

File tree

2 files changed

+135
-2
lines changed

2 files changed

+135
-2
lines changed

src/pymatgen/transformations/standard_transformations.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pymatgen.transformations.transformation_abc import AbstractTransformation
2424

2525
if TYPE_CHECKING:
26+
from numpy.random import Generator
2627
from typing_extensions import Self
2728

2829
from pymatgen.core.sites import PeriodicSite
@@ -451,6 +452,7 @@ class OrderDisorderedStructureTransformation(AbstractTransformation):
451452
ALGO_FAST = 0
452453
ALGO_COMPLETE = 1
453454
ALGO_BEST_FIRST = 2
455+
ALGO_RANDOM = -1
454456

455457
def __init__(self, algo=ALGO_FAST, symmetrized_structures=False, no_oxi_states=False):
456458
"""
@@ -467,7 +469,9 @@ def __init__(self, algo=ALGO_FAST, symmetrized_structures=False, no_oxi_states=F
467469
self.no_oxi_states = no_oxi_states
468470
self.symmetrized_structures = symmetrized_structures
469471

470-
def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False) -> Structure:
472+
def apply_transformation(
473+
self, structure: Structure, return_ranked_list: bool | int = False, occ_tol=0.25
474+
) -> Structure:
471475
"""For this transformation, the apply_transformation method will return
472476
only the ordered structure with the lowest Ewald energy, to be
473477
consistent with the method signature of the other transformations.
@@ -478,6 +482,9 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
478482
structure: Oxidation state decorated disordered structure to order
479483
return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures
480484
is returned. If False, only the single lowest energy structure is returned. Defaults to False.
485+
occ_tol (float): Occupancy tolerance. If the total occupancy of a group is within this value
486+
of an integer, it is rounded to that integer; otherwise a ValueError is raised.
487+
Defaults to 0.25.
481488
482489
Returns:
483490
Depending on returned_ranked list, either a transformed structure
@@ -529,14 +536,25 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
529536
# generate the list of manipulations and input structure
530537
struct = Structure.from_sites(structure)
531538

539+
# We will first create an initial ordered structure by filling all sites
540+
# with the species that has the highest oxidation state (initial_sp)
541+
# replacing all other species on a given site.
542+
# then, we process a list of manipulations to get the final structure.
543+
# The manipulations are of the format:
544+
# [oxi_ratio, 1, [0,1,2,3], Li+]
545+
# which means -- Place 1 Li+ in any of these 4 sites
546+
# the oxi_ratio is the ratio of the oxidation state of the species to
547+
# the initial species. This is used to determine the energy of the
548+
# manipulation in the EwaldMinimizer, but is not used in the purely random
549+
# algorithm.
532550
manipulations = []
533551
for group in equivalent_sites:
534552
total_occupancy = dict(
535553
sum((structure[idx].species for idx in group), Composition()).items() # type: ignore[attr-defined]
536554
)
537555
# round total occupancy to possible values
538556
for key, val in total_occupancy.items():
539-
if abs(val - round(val)) > 0.25:
557+
if abs(val - round(val)) > occ_tol:
540558
raise ValueError("Occupancy fractions not consistent with size of unit cell")
541559
total_occupancy[key] = round(val)
542560
# start with an ordered structure
@@ -555,6 +573,16 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
555573
if empty > 0.5:
556574
manipulations.append([0, empty, list(group), None])
557575

576+
if self.algo == self.ALGO_RANDOM:
577+
rand_structures = get_randomly_manipulated_structures(
578+
struct=struct, manipulations=manipulations, n_return=n_to_return
579+
)
580+
if return_ranked_list:
581+
return [
582+
{"energy": 0.0, "energy_above_minimum": 0.0, "structure": s} for s in rand_structures[:n_to_return]
583+
]
584+
return rand_structures[0]
585+
558586
matrix = EwaldSummation(struct).total_energy_matrix
559587
ewald_m = EwaldMinimizer(matrix, manipulations, n_to_return, self.algo)
560588

@@ -891,3 +919,82 @@ def apply_transformation(self, structure):
891919

892920
def __repr__(self):
893921
return "ScaleToRelaxedTransformation"
922+
923+
924+
def _sample_random_manipulation(manipulation, rng, manipulated) -> list[tuple[int, SpeciesLike]]:
    """Randomly pick the sites targeted by a single manipulation.

    Args:
        manipulation: Sequence ``(oxi_ratio, nsites, indices, sp)`` meaning
            "choose ``nsites`` entries from ``indices`` and assign species
            ``sp`` to them". ``oxi_ratio`` is not used here.
        rng: Random generator whose ``choice(a, size, replace=False)`` returns
            an object with ``tolist()`` (e.g. a numpy ``Generator``).
        manipulated: ``(site_index, species)`` pairs produced by earlier
            manipulations. Their site indices are excluded from the draw.

    Returns:
        ``(site_index, sp)`` pairs for the sampled sites, sorted by site index.

    Raises:
        RuntimeError: If fewer than ``nsites`` of the listed sites are still free.
    """
    _, nsites, indices, sp = manipulation
    # Use a set for O(1) membership tests. ``allowed_sites`` still follows the
    # order of ``indices``, so the draw for a given rng state is unchanged.
    taken = {site_idx for site_idx, _ in manipulated}
    allowed_sites = [site_idx for site_idx in indices if site_idx not in taken]
    if len(allowed_sites) < nsites:
        raise RuntimeError(
            "No valid manipulations possible. "
            f" You have already applied a manipulation to each site in this group {indices}"
        )
    # Sampling without replacement guarantees nsites distinct sites.
    sampled_sites = sorted(rng.choice(allowed_sites, nsites, replace=False).tolist())
    return [(site_idx, sp) for site_idx in sampled_sites]
943+
944+
945+
def _get_manipulation(manipulations: list, rng: Generator, max_attempts, seen: set[tuple]) -> tuple:
    """Sample one combined set of site assignments that has not been seen yet.

    On every attempt each manipulation is sampled in order via
    ``_sample_random_manipulation``; sites picked by earlier manipulations in
    the same attempt are passed on so they are not picked again.

    Args:
        manipulations: Manipulations of the form ``(oxi_ratio, nsites, indices, sp)``.
        rng: Random generator forwarded to ``_sample_random_manipulation``.
        max_attempts: Maximum number of samples to draw before giving up.
        seen: Previously returned results. This function does not modify it.

    Returns:
        Tuple of ``(site_index, species)`` pairs that is not contained in ``seen``.

    Raises:
        RuntimeError: If no unseen combination was drawn within ``max_attempts``.
    """
    for _ in range(max_attempts):
        manipulated: list[tuple] = []
        for manip in manipulations:
            manipulated.extend(_sample_random_manipulation(manip, rng, manipulated))
        candidate = tuple(manipulated)
        if candidate not in seen:
            return candidate
    # The two sentences were previously concatenated without a separator.
    raise RuntimeError(
        "Could not apply manipulations to structure. "
        "This is likely because you have already applied all the possible manipulations"
    )
959+
960+
961+
def _apply_manip(struct, manipulations) -> Structure:
    """Return a copy of ``struct`` with the given site assignments applied.

    Args:
        struct: Object providing ``copy()``, ``replace(idx, sp)`` and
            ``remove_sites(indices)``. It is copied, not modified.
        manipulations: Iterable of ``(site_index, species)`` pairs. A species
            of ``None`` marks the site for removal.

    Returns:
        The modified copy.
    """
    result = struct.copy()
    to_remove = []
    for site_idx, species in manipulations:
        if species is not None:
            result.replace(site_idx, species)
            continue
        to_remove.append(site_idx)
    # Removals are done in a single call after all replacements
    # (presumably so earlier site indices are not shifted; confirm).
    result.remove_sites(to_remove)
    return result
973+
974+
975+
def get_randomly_manipulated_structures(
    struct: Structure, manipulations: list, seed=None, n_return: int = 1, max_attempts: int = 1000
) -> list[Structure]:
    """Get structures with distinct random manipulations applied.

    Args:
        struct: Input structure
        manipulations: List of manipulations to apply
        seed: Seed for random number generator
        n_return: Number of structures to return
        max_attempts: Maximum number of random draws made for each returned
            structure while looking for a combination that has not been
            produced yet. Defaults to 1000.

    Returns:
        List of ``n_return`` structures with manipulations applied, in the
        order they were sampled.

    Raises:
        RuntimeError: Raised by ``_get_manipulation`` when no unseen
            combination is found within ``max_attempts`` draws.
    """
    rng = np.random.default_rng(seed)
    seen: set[tuple] = set()
    # ``seen`` provides fast duplicate checks; the list keeps sampling order.
    sampled_manips = []
    for _ in range(n_return):
        manip = _get_manipulation(manipulations, rng, max_attempts, seen)
        seen.add(manip)
        sampled_manips.append(manip)
    return [_apply_manip(struct, manip) for manip in sampled_manips]

tests/transformations/test_standard_transformations.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,32 @@ def test_best_first(self):
401401
output = trafo.apply_transformation(struct, return_ranked_list=3)
402402
assert output[0]["energy"] == approx(-234.57813667648315, abs=1e-4)
403403

404+
def test_random_sample(self):
    """Check the ALGO_RANDOM output shape and that composition is preserved."""
    # Two-site Si cell in mcsqs format.
    mcsqs_str = (
        "3.333573 0.000000 1.924639\n"
        "1.111191 3.142924 1.924639\n"
        "0.000000 0.000000 3.849278\n"
        "1.0 0.0 0.0\n"
        "0.0 1.0 0.0\n"
        "0.0 0.0 1.0\n"
        "0.875000 0.875000 0.875000 Si=1\n"
        "0.125000 0.125000 0.125000 Si=1"
    )
    struct = Structure.from_str(mcsqs_str, fmt="mcsqs") * [3, 2, 1]
    # Make the first two sites disordered (half Fe, half Ni).
    for site_idx in (0, 1):
        struct.replace(site_idx, {"Fe": 0.5, "Ni": 0.5})

    random_algo = OrderDisorderedStructureTransformation.ALGO_RANDOM
    trafo = OrderDisorderedStructureTransformation(algo=random_algo, no_oxi_states=True)

    # Ranked-list mode: three entries, each with the expected keys.
    ranked = trafo.apply_transformation(struct * [2, 2, 2], return_ranked_list=3)
    assert len(ranked) == 3
    expected_keys = {"structure", "energy", "energy_above_minimum"}
    for entry in ranked:
        assert set(entry.keys()) == expected_keys

    # Single-structure mode: the reduced formula must match the input.
    single = trafo.apply_transformation(struct * [2, 2, 2], return_ranked_list=False)
    assert single.composition.reduced_formula == struct.composition.reduced_formula
429+
404430

405431
class TestPrimitiveCellTransformation:
406432
def test_apply_transformation(self):

0 commit comments

Comments
 (0)