Skip to content

Commit 3ba6384

Browse files
pre-commit auto-fixes
1 parent 8cb1729 commit 3ba6384

File tree

2 files changed

+39
-33
lines changed

2 files changed

+39
-33
lines changed

pymatgen/analysis/fstar/fstar.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66
Yin, L. et al.Thermodynamics of Antisite Defects in Layered NMC Cathodes: Systematic Insights from High-Precision Powder
77
Diffraction Analyses Chem. Mater 2020 32 (3), 1002-1010. 10.1021/acs.chemmater.9b03646
88
"""
9+
from __future__ import annotations
910

1011
import os
12+
from typing import TYPE_CHECKING
13+
1114
import numpy as np
1215
import pandas as pd
1316
import plotly.express as px
14-
from typing import Callable
15-
from pymatgen.core.periodic_table import Element
16-
from pymatgen.core.structure import Structure
17+
18+
if TYPE_CHECKING:
19+
from pymatgen.core.periodic_table import Element
20+
from pymatgen.core.structure import Structure
1721

1822
# Load in the neutron form factors
1923
with open(os.path.join(os.path.dirname(__file__), "neutron_factors.csv")) as f:
@@ -25,7 +29,8 @@ class FStarDiagram:
2529
"""
2630
Take a list of symmetrized structure objects and use them to generate an f* phase diagram.
2731
"""
28-
def __init__(self, structures: list[Structure], scattering_type: str='X-ray'):
32+
33+
def __init__(self, structures: list[Structure], scattering_type: str = "X-ray"):
2934
"""
3035
Initialize the f* diagram generator with the list of structures and scattering type.
3136
@@ -39,26 +44,22 @@ def __init__(self, structures: list[Structure], scattering_type: str='X-ray'):
3944
# check if the input structures list is valid
4045
for ind, struct in enumerate(structures):
4146
try:
42-
if struct.equivalent_indices == structures[ind-1].equivalent_indices:
47+
if struct.equivalent_indices == structures[ind - 1].equivalent_indices:
4348
continue
4449
else:
45-
raise ValueError(
46-
"All structues must only vary in occupancy."
47-
)
50+
raise ValueError("All structues must only vary in occupancy.")
4851
except AttributeError:
49-
raise AttributeError(
50-
"Must use symmeteized structure objects"
51-
)
52+
raise AttributeError("Must use symmeteized structure objects")
5253
self._structures = structures
5354
self._scatter = scattering_type
54-
self._scatter_dict = {'X-ray':self.xray_scatter,'Neutron':self.neutron_scatter}
55+
self._scatter_dict = {"X-ray": self.xray_scatter, "Neutron": self.neutron_scatter}
5556
self.site_labels = self._get_site_labels()
5657
self.fstar_coords = self._get_fstar_coords()
5758
self.set_plot_list([self.site_labels[0], self.site_labels[1], self.site_labels[2]])
5859
self.make_plot()
5960
print("The labels for this structure's unique sites are")
6061
print(self.site_labels)
61-
62+
6263
def combine_sites(self, site_lists: list[list[str]]) -> None:
6364
"""
6465
Many structures have more than three sites. If this is the case you may want to
@@ -70,12 +71,11 @@ def combine_sites(self, site_lists: list[list[str]]) -> None:
7071
for combo in site_lists:
7172
for site in combo:
7273
if site not in self.site_labels:
73-
raise ValueError(
74-
"All sites must be in the site_labels list"
75-
)
74+
raise ValueError("All sites must be in the site_labels list")
7675
self.fstar_coords[str(combo)] = sum([self.fstar_coords[site] for site in combo])
7776
if str(combo) not in self.site_labels:
7877
self.site_labels.append(str(combo))
78+
7979
def set_plot_list(self, site_list: list[str]) -> None:
8080
"""
8181
set the list of sites to plot and the order to plot them in.
@@ -85,19 +85,20 @@ def set_plot_list(self, site_list: list[str]) -> None:
8585
"""
8686
for site in site_list:
8787
if site not in self.site_labels:
88-
raise ValueError(
89-
"All sites must be in the site_labels list"
90-
)
88+
raise ValueError("All sites must be in the site_labels list")
9189
self.plot_list = site_list
90+
9291
def make_plot(self, **kwargs):
9392
"""
9493
Makes a plotly express scatter_ternary plot useing the fstar_coords dataframe and the
9594
sites in plot list.
9695
Args:
9796
**kwargs: this can be any argument that the scatter_ternary fucntion can use.
9897
"""
99-
self.plot = px.scatter_ternary(data_frame=self.fstar_coords, a=self.plot_list[0], b=self.plot_list[1],
100-
c=self.plot_list[2], **kwargs)
98+
self.plot = px.scatter_ternary(
99+
data_frame=self.fstar_coords, a=self.plot_list[0], b=self.plot_list[1], c=self.plot_list[2], **kwargs
100+
)
101+
101102
def _get_site_labels(self):
102103
"""
103104
Generates unique site labels based on composition, order, and symetry equivalence in the structure object.
@@ -132,8 +133,10 @@ def _get_site_labels(self):
132133
"""
133134
site_labels = []
134135
for site in self._structures[0].equivalent_indices:
135-
site_labels.append(str(self._structures[0][site[0]].frac_coords) + \
136-
[str(sp) for sp, _ in self._structures[0][site[0]].species.items()][0])
136+
site_labels.append(
137+
str(self._structures[0][site[0]].frac_coords)
138+
+ next(str(sp) for sp, _ in self._structures[0][site[0]].species.items())
139+
)
137140
return site_labels
138141

139142
def _get_fstar_coords(self):
@@ -152,31 +155,32 @@ def _get_fstar_coords(self):
152155
for sp, occ in elements_and_occupancies:
153156
# ind1 and ind2 are added in case someone wants to make a custom scatter function
154157
# that uses information in the structure object
155-
f_occ = self._scatter_dict[self._scatter](sp, occ, ind1, ind2)
158+
f_occ = self._scatter_dict[self._scatter](sp, occ, ind1, ind2)
156159
occ_f_list.append(f_occ)
157160
fstar = np.absolute(mult * sum(occ_f_list))
158161
fstar_df.loc[0][column[0]] = round(float(fstar), 4)
159162
tot = sum(sum(list(fstar_df.values)))
160163
fstar_df = pd.DataFrame(columns=self.site_labels, data=[fs / tot for fs in list(fstar_df.values)])
161164
fstar_df_full = pd.concat([fstar_df_full, fstar_df], ignore_index=True)
162165
return fstar_df_full
166+
163167
def xray_scatter(self, el: Element, occ: float, i1: int, i2: int) -> float:
164168
"""
165169
X-ray scattering function. i2 and i2 are unused.
166170
"""
167-
f_occ = el.Z * occ
168-
return f_occ
171+
return el.Z * occ
172+
169173
def neutron_scatter(self, el: Element, occ: float, i1: int, i2: int) -> float:
170174
"""
171175
Neutron scattering function. i2 and i2 are unused.
172176
"""
173-
for i, n in enumerate(NEUTRON_SCATTER_DF['Isotope'].values):
177+
for i, n in enumerate(NEUTRON_SCATTER_DF["Isotope"].values):
174178
if hasattr(el, "element"):
175179
if n == str(el.element):
176-
f_occ = float(NEUTRON_SCATTER_DF.loc[i]['Coh b']) * occ
180+
f_occ = float(NEUTRON_SCATTER_DF.loc[i]["Coh b"]) * occ
177181
break
178182
else:
179183
if n == str(el):
180-
f_occ = float(NEUTRON_SCATTER_DF.loc[i]['Coh b']) * occ
184+
f_occ = float(NEUTRON_SCATTER_DF.loc[i]["Coh b"]) * occ
181185
break
182-
return f_occ
186+
return f_occ

tests/analysis/fstar/test_fstar.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
class Test_FStarDiagram(PymatgenTest):
1111
def setUp(self):
12-
self.cif_gz = tarfile.open(f"{TEST_FILES_DIR}/fstar/fstar.tar.gz","r")
12+
self.cif_gz = tarfile.open(f"{TEST_FILES_DIR}/fstar/fstar.tar.gz", "r")
1313
self.struct_list = [
14-
CifParser.from_str(self.cif_gz.extractfile(file).read().decode('utf-8')).get_structures(
15-
primitive=False, symmetrized=True, check_occu=False)[0] for file in self.cif_gz.getnames()
14+
CifParser.from_str(self.cif_gz.extractfile(file).read().decode("utf-8")).get_structures(
15+
primitive=False, symmetrized=True, check_occu=False
16+
)[0]
17+
for file in self.cif_gz.getnames()
1618
]
1719
self.fstar = FStarDiagram(structures=self.struct_list)
1820

0 commit comments

Comments
 (0)