Skip to content

Commit e9cea13

Browse files
make edits
compressed the cifs into fstar.tar.gz. Implemented type hints. reworked to simplify and run better.
1 parent 6314751 commit e9cea13

24 files changed

+106
-3809
lines changed

pymatgen/analysis/fstar/fstar.py

Lines changed: 96 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
atomic form factors. Rev. Sci. Instrum. 89, 093002 (2018). 10.1063/1.5044555
66
Yin, L. et al.Thermodynamics of Antisite Defects in Layered NMC Cathodes: Systematic Insights from High-Precision Powder
77
Diffraction Analyses Chem. Mater 2020 32 (3), 1002-1010. 10.1021/acs.chemmater.9b03646
8-
98
"""
10-
from __future__ import annotations
119

1210
import os
13-
1411
import numpy as np
1512
import pandas as pd
1613
import plotly.express as px
14+
from typing import Callable
15+
from pymatgen.core.periodic_table import Element
16+
from pymatgen.core.structure import Structure
1717

1818
# Load in the neutron form factors
1919
with open(os.path.join(os.path.dirname(__file__), "neutron_factors.csv")) as f:
@@ -23,64 +23,84 @@
2323

2424
class FStarDiagram:
2525
"""
26-
Take a list of structure objects and/or cifs and use them to generate an f* phase diagram.
26+
Take a list of symmetrized structure objects and use them to generate an f* phase diagram.
2727
"""
28-
29-
def __init__(self, structures, scattering_type="X-ray", custom_scatter=None):
28+
def __init__(self, structures: list[Structure], scattering_type: str='X-ray'):
3029
"""
3130
Initialize the f* diagram generator with the list of structures and scattering type.
3231
3332
Args:
3433
structures(list): List of structure objects to use in the diagram. These MUST be symmetrized structure
35-
objects.
34+
objects. All structures must only vary in the occupancy of the sites.
3635
scattering_type(str): Type of scattering to use in the f* calculation. Defaults to 'X-ray'
3736
which uses the atomic number as the scattering factor. 'Neutron' is a built in scattering
38-
type which uses neutron scattering factors. 'Custom' allows the user to supplement their
39-
own calculation with any set of scattering factors.
40-
custom_scatter(function): when using custom scattering set this equal to a global varialble that is equal
41-
to the custom scattering function.
37+
type which uses neutron scattering factors.
4238
"""
43-
39+
# check if the input structures list is valid
40+
for ind, struct in enumerate(structures):
41+
try:
42+
if struct.equivalent_indices == structures[ind-1].equivalent_indices:
43+
continue
44+
else:
45+
raise ValueError(
46+
"All structues must only vary in occupancy."
47+
)
48+
except AttributeError:
49+
raise AttributeError(
50+
"Must use symmeteized structure objects"
51+
)
4452
self._structures = structures
4553
self._scatter = scattering_type
46-
self._custscat = custom_scatter
47-
self._equiv_inds = [struct.equivalent_indices for struct in self._structures]
48-
self.site_labels = self.get_site_labels()
49-
self.coords = self.get_fstar_coords()
50-
self.plot = px.scatter_ternary(
51-
data_frame=self.coords, a=self.site_labels[0], b=self.site_labels[1], c=self.site_labels[2]
52-
)
54+
self._scatter_dict = {'X-ray':self.xray_scatter,'Neutron':self.neutron_scatter}
55+
self.site_labels = self._get_site_labels()
56+
self.fstar_coords = self._get_fstar_coords()
57+
self.set_plot_list([self.site_labels[0], self.site_labels[1], self.site_labels[2]])
58+
self.make_plot()
5359
print("The labels for this structure's unique sites are")
5460
print(self.site_labels)
55-
56-
def edit_fstar_diagram(self, combine_list=False, plot_list=False, **kwargs):
61+
62+
def combine_sites(self, site_lists: list[list[str]]) -> None:
5763
"""
58-
Edit the plot of the f* diagram using plotly express.
59-
64+
Many structures have more than three sites. If this is the case you may want to
65+
add some sites together to make a psudo-site.
6066
Args:
61-
combine_list(list): This is a list of lists which indicates what unique sites need to be combined to make
62-
the plot ternary.
63-
plot_list(list): This is a list that indicates what unique sites or combined sites to plot and what order to
64-
plot them in.
65-
kwargs: use this to add any other arguments from scatter_ternary .
66-
"""
67-
if combine_list:
68-
for combo in combine_list:
69-
self.coords[str(combo)] = sum([self.coords[site] for site in combo])
70-
if str(combo) not in self.site_labels:
71-
self.site_labels.append(str(combo))
72-
if plot_list:
73-
self.plot = px.scatter_ternary(
74-
data_frame=self.coords, a=plot_list[0], b=plot_list[1], c=plot_list[2], **kwargs
75-
)
76-
else:
77-
self.plot = px.scatter_ternary(
78-
data_frame=self.coords, a=self.site_labels[0], b=self.site_labels[1], c=self.site_labels[2], **kwargs
79-
)
80-
81-
def get_site_labels(self):
67+
site_lists(list): A list of lists of site lable strings. This allows you to combine
68+
more than one set of sites at once.
69+
"""
70+
for combo in site_lists:
71+
for site in combo:
72+
if site not in self.site_labels:
73+
raise ValueError(
74+
"All sites must be in the site_labels list"
75+
)
76+
self.fstar_coords[str(combo)] = sum([self.fstar_coords[site] for site in combo])
77+
if str(combo) not in self.site_labels:
78+
self.site_labels.append(str(combo))
79+
def set_plot_list(self, site_list: list[str]) -> None:
80+
"""
81+
set the list of sites to plot and the order to plot them in.
82+
Args:
83+
site_list(list): A list of site lable strings. Index 0 goes on the top of the
84+
plot, index 1 goes on the bottom left, and index 2 goes on the bottom right.
85+
"""
86+
for site in site_list:
87+
if site not in self.site_labels:
88+
raise ValueError(
89+
"All sites must be in the site_labels list"
90+
)
91+
self.plot_list = site_list
92+
def make_plot(self, **kwargs):
8293
"""
83-
Generates unique site labels based on composition, order, and symmetry equivalence in the structure object.
94+
Makes a plotly express scatter_ternary plot useing the fstar_coords dataframe and the
95+
sites in plot list.
96+
Args:
97+
**kwargs: this can be any argument that the scatter_ternary fucntion can use.
98+
"""
99+
self.plot = px.scatter_ternary(data_frame=self.fstar_coords, a=self.plot_list[0], b=self.plot_list[1],
100+
c=self.plot_list[2], **kwargs)
101+
def _get_site_labels(self):
102+
"""
103+
Generates unique site labels based on composition, order, and symetry equivalence in the structure object.
84104
Ex:
85105
Structure Summary
86106
Lattice
@@ -110,58 +130,53 @@ def get_site_labels(self):
110130
(0.0000, 0.0000, 7.1375) [0.0000, 0.0000, 0.5000]
111131
'[0. 0. 0.25]O' - PeriodicSite: O (0.0000, 0.0000, 3.5688) [0.0000, 0.0000, 0.2500]
112132
"""
133+
site_labels = []
134+
for site in self._structures[0].equivalent_indices:
135+
site_labels.append(str(self._structures[0][site[0]].frac_coords) + \
136+
[str(sp) for sp, _ in self._structures[0][site[0]].species.items()][0])
137+
return site_labels
113138

114-
site_labels_fin = []
115-
for ind1, struct in enumerate(self._equiv_inds):
116-
site_labels = []
117-
for _ind2, site in enumerate(struct):
118-
label = str(self._structures[ind1][site[0]].frac_coords) + next(
119-
str(sp) for sp, occ in self._structures[ind1][site[0]].species.items()
120-
)
121-
if label not in site_labels:
122-
site_labels.append(label)
123-
if len(site_labels) > len(site_labels_fin):
124-
site_labels_fin = site_labels
125-
return site_labels_fin
126-
127-
def get_fstar_coords(self):
139+
def _get_fstar_coords(self):
128140
"""
129141
Calculate the f* coordinates for the list of structures.
130142
"""
131-
132143
fstar_df_full = pd.DataFrame(columns=self.site_labels)
133-
134-
for ind1, struct in enumerate(self._equiv_inds):
144+
for ind1, struct in enumerate(self._structures):
135145
fstar_df = pd.DataFrame(columns=self.site_labels, data=[[0.0 for i in self.site_labels]])
136-
for ind2, site in enumerate(struct):
146+
for ind2, site in enumerate(struct.equivalent_indices):
137147
occ_f_list = []
138148
mult = len(site)
139149
site_frac_coord = str(self._structures[ind1][site[0]].frac_coords)
140150
column = [label for label in self.site_labels if site_frac_coord in label]
141151
elements_and_occupancies = self._structures[ind1][site[0]].species.items()
142152
for sp, occ in elements_and_occupancies:
143-
if self._scatter == "X-ray":
144-
f_occ = sp.Z * occ
145-
if self._scatter == "Neutron":
146-
for i, n in enumerate(NEUTRON_SCATTER_DF["Isotope"].values):
147-
if hasattr(sp, "element"):
148-
if n == str(sp.element):
149-
f_occ = float(NEUTRON_SCATTER_DF.loc[i]["Coh b"]) * occ
150-
break
151-
else:
152-
if n == str(sp):
153-
f_occ = float(NEUTRON_SCATTER_DF.loc[i]["Coh b"]) * occ
154-
break
155-
if self._scatter == "Custom":
156-
if hasattr(sp, "element"):
157-
f_occ = self._custscat(str(sp.element), occ, ind1, ind2)
158-
else:
159-
f_occ = self._custscat(str(sp), occ, ind1, ind2)
153+
# ind1 and ind2 are added in case someone wants to make a custom scatter function
154+
# that uses information in the structure object
155+
f_occ = self._scatter_dict[self._scatter](sp, occ, ind1, ind2)
160156
occ_f_list.append(f_occ)
161-
162157
fstar = np.absolute(mult * sum(occ_f_list))
163158
fstar_df.loc[0][column[0]] = round(float(fstar), 4)
164159
tot = sum(sum(list(fstar_df.values)))
165160
fstar_df = pd.DataFrame(columns=self.site_labels, data=[fs / tot for fs in list(fstar_df.values)])
166161
fstar_df_full = pd.concat([fstar_df_full, fstar_df], ignore_index=True)
167162
return fstar_df_full
163+
def xray_scatter(self, el: Element, occ: float, i1: int, i2: int) -> float:
164+
"""
165+
X-ray scattering function. i2 and i2 are unused.
166+
"""
167+
f_occ = el.Z * occ
168+
return f_occ
169+
def neutron_scatter(self, el: Element, occ: float, i1: int, i2: int) -> float:
170+
"""
171+
Neutron scattering function. i2 and i2 are unused.
172+
"""
173+
for i, n in enumerate(NEUTRON_SCATTER_DF['Isotope'].values):
174+
if hasattr(el, "element"):
175+
if n == str(el.element):
176+
f_occ = float(NEUTRON_SCATTER_DF.loc[i]['Coh b']) * occ
177+
break
178+
else:
179+
if n == str(el):
180+
f_occ = float(NEUTRON_SCATTER_DF.loc[i]['Coh b']) * occ
181+
break
182+
return f_occ

tests/analysis/fstar/test_fstar.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
import os
3+
import tarfile
44

55
from pymatgen.analysis.fstar.fstar import FStarDiagram
66
from pymatgen.io.cif import CifParser
@@ -9,29 +9,28 @@
99

1010
class Test_FStarDiagram(PymatgenTest):
1111
def setUp(self):
12-
self.cif_list = [file for file in os.listdir(f"{TEST_FILES_DIR}/analysis/fstar") if file.endswith(".cif")]
12+
self.cif_gz = tarfile.open(f"{TEST_FILES_DIR}/fstar/fstar.tar.gz","r")
1313
self.struct_list = [
14-
CifParser(f"{TEST_FILES_DIR}/analysis/fstar/" + file).get_structures(
15-
primitive=False, symmetrized=True, check_occu=False
16-
)[0]
17-
for file in self.cif_list
14+
CifParser.from_str(self.cif_gz.extractfile(file).read().decode('utf-8')).get_structures(
15+
primitive=False, symmetrized=True, check_occu=False)[0] for file in self.cif_gz.getnames()
1816
]
1917
self.fstar = FStarDiagram(structures=self.struct_list)
2018

2119
def test_edit_fstar_diagram(self):
2220
assert self.fstar.site_labels == ["[0. 0. 0.]Li", "[0. 0. 0.5]Co", "[0. 0. 0.25]O"]
2321
new = FStarDiagram(structures=self.struct_list)
2422
assert self.fstar.plot == new.plot
25-
new.edit_fstar_diagram(combine_list=[["[0. 0. 0.5]Co", "[0. 0. 0.]Li"]])
23+
new.combine_sites(site_lists=[["[0. 0. 0.5]Co", "[0. 0. 0.]Li"]])
2624
assert new.site_labels == [
2725
"[0. 0. 0.]Li",
2826
"[0. 0. 0.5]Co",
2927
"[0. 0. 0.25]O",
3028
"['[0. 0. 0.5]Co', '[0. 0. 0.]Li']",
3129
]
32-
assert list(new.coords["['[0. 0. 0.5]Co', '[0. 0. 0.]Li']"].to_numpy()) == list(
33-
self.fstar.coords["[0. 0. 0.]Li"].to_numpy() + self.fstar.coords["[0. 0. 0.5]Co"].to_numpy()
30+
assert list(new.fstar_coords["['[0. 0. 0.5]Co', '[0. 0. 0.]Li']"].to_numpy()) == list(
31+
self.fstar.fstar_coords["[0. 0. 0.]Li"].to_numpy() + self.fstar.fstar_coords["[0. 0. 0.5]Co"].to_numpy()
3432
)
35-
assert self.fstar.plot == new.plot
36-
new.edit_fstar_diagram(plot_list=["[0. 0. 0.5]Co", "[0. 0. 0.25]O", "[0. 0. 0.]Li"])
33+
new.set_plot_list(site_list=["[0. 0. 0.5]Co", "[0. 0. 0.25]O", "[0. 0. 0.]Li"])
34+
assert self.fstar.plot_list != new.plot_list
35+
new.make_plot()
3736
assert self.fstar.plot != new.plot

0 commit comments

Comments
 (0)