Skip to content

Commit 0197856

Browse files
committed
More cleanup.
1 parent bbcd3d7 commit 0197856

File tree

5 files changed

+43
-42
lines changed

5 files changed

+43
-42
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ exclude_also = [
301301
ignore_missing_imports = true
302302
namespace_packages = true
303303
no_implicit_optional = false
304-
disable_error_code = ["annotation-unchecked", "override", "operator", "attr-defined", "union-attr", "misc", "call-overload"]
304+
disable_error_code = ["annotation-unchecked", "override", "operator", "attr-defined", "union-attr", "misc", "call-overload", "index"]
305305
exclude = ['src/pymatgen/analysis', 'src/pymatgen/io/cp2k', 'src/pymatgen/io/lammps']
306306
plugins = ["numpy.typing.mypy_plugin"]
307307

@@ -311,7 +311,7 @@ ignore_missing_imports = true
311311

312312
[tool.codespell]
313313
# TODO: un-ignore "ist/nd/ot/ontop/CoO" once support file-level ignore with pattern
314-
ignore-words-list = """Nd, Te, titel, Mater,
314+
ignore-words-list = """Nd, Te, titel, Mater, nax,
315315
Hart, Lew, Rute, atomate,
316316
ist, nd, ot, ontop, CoO
317317
"""

src/pymatgen/phonon/bandstructure.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def get_reasonable_repetitions(n_atoms: int) -> tuple[int, int, int]:
3535
return 1, 1, 1
3636

3737

38-
def eigenvectors_from_displacements(disp: np.ndarray, masses: np.ndarray) -> np.ndarray:
38+
def eigenvectors_from_displacements(disp: ArrayLike, masses: ArrayLike) -> np.ndarray:
3939
"""Calculate the eigenvectors from the atomic displacements."""
40-
return np.einsum("nax,a->nax", disp, masses**0.5) # codespell:ignore nax
40+
return np.einsum("nax,a->nax", disp, masses**0.5) # type:ignore[arg-type]
4141

4242

4343
def estimate_band_connection(prev_eigvecs, eigvecs, prev_band_order) -> list[int]:
@@ -67,12 +67,12 @@ class PhononBandStructure(MSONable):
6767

6868
def __init__(
6969
self,
70-
qpoints: Sequence[Kpoint],
70+
qpoints: ArrayLike,
7171
frequencies: ArrayLike,
7272
lattice: Lattice,
73-
nac_frequencies: Sequence[Sequence] | None = None,
73+
nac_frequencies: ArrayLike | None = None,
7474
eigendisplacements: ArrayLike = None,
75-
nac_eigendisplacements: Sequence[Sequence] | None = None,
75+
nac_eigendisplacements: ArrayLike | None = None,
7676
labels_dict: dict | None = None,
7777
coords_are_cartesian: bool = False,
7878
structure: Structure | None = None,
@@ -126,14 +126,14 @@ def __init__(
126126
if np.linalg.norm(q_pt - np.array(labels_dict[key])) < 0.0001:
127127
label = key
128128
self.labels_dict[label] = Kpoint(
129-
q_pt,
129+
q_pt, # type:ignore[arg-type]
130130
lattice,
131131
label=label,
132132
coords_are_cartesian=coords_are_cartesian,
133133
)
134134
self.qpoints += [
135135
Kpoint(
136-
q_pt,
136+
q_pt, # type:ignore[arg-type]
137137
lattice,
138138
label=label,
139139
coords_are_cartesian=coords_are_cartesian,
@@ -148,10 +148,10 @@ def __init__(
148148
self.nac_eigendisplacements: list[tuple[list[float], np.ndarray]] = []
149149
if nac_frequencies is not None:
150150
for freq in nac_frequencies:
151-
self.nac_frequencies.append(([idx / np.linalg.norm(freq[0]) for idx in freq[0]], freq[1]))
151+
self.nac_frequencies.append(([idx / np.linalg.norm(freq[0]) for idx in freq[0]], freq[1])) # type:ignore[arg-type]
152152
if nac_eigendisplacements is not None:
153153
for freq in nac_eigendisplacements:
154-
self.nac_eigendisplacements.append(([idx / np.linalg.norm(freq[0]) for idx in freq[0]], freq[1]))
154+
self.nac_eigendisplacements.append(([idx / np.linalg.norm(freq[0]) for idx in freq[0]], freq[1])) # type:ignore[arg-type]
155155

156156
def get_gamma_point(self) -> Kpoint | None:
157157
"""Get the Gamma q-point as a Kpoint object (or None if not found)."""
@@ -219,7 +219,7 @@ def has_nac(self) -> bool:
219219
@property
220220
def has_eigendisplacements(self) -> bool:
221221
"""True if eigendisplacements are present."""
222-
return len(self.eigendisplacements) > 0
222+
return len(self.eigendisplacements) > 0 # type:ignore[arg-type]
223223

224224
def get_nac_frequencies_along_dir(self, direction: Sequence) -> np.ndarray | None:
225225
"""Get the nac_frequencies for the given direction (not necessarily a versor).
@@ -353,7 +353,7 @@ class PhononBandStructureSymmLine(PhononBandStructure):
353353

354354
def __init__(
355355
self,
356-
qpoints: Sequence[Kpoint],
356+
qpoints: ArrayLike,
357357
frequencies: ArrayLike,
358358
lattice: Lattice,
359359
has_nac: bool = False,
@@ -388,7 +388,7 @@ def __init__(
388388
provide projections to the band structure.
389389
"""
390390
super().__init__(
391-
qpoints=qpoints,
391+
qpoints=qpoints, # type:ignore[arg-type]
392392
frequencies=frequencies,
393393
lattice=lattice,
394394
nac_frequencies=None,
@@ -398,7 +398,7 @@ def __init__(
398398
coords_are_cartesian=coords_are_cartesian,
399399
structure=structure,
400400
)
401-
self._reuse_init(eigendisplacements, frequencies, has_nac, qpoints)
401+
self._reuse_init(eigendisplacements, frequencies, has_nac, qpoints) # type:ignore[arg-type]
402402

403403
def __repr__(self) -> str:
404404
bands, labels = self.bands.shape, list(self.labels_dict)
@@ -425,7 +425,7 @@ def _reuse_init(
425425
self.distance += [previous_distance]
426426
else:
427427
self.distance += [
428-
np.linalg.norm(self.qpoints[idx].cart_coords - previous_qpoint.cart_coords) + previous_distance
428+
np.linalg.norm(self.qpoints[idx].cart_coords - previous_qpoint.cart_coords) + previous_distance # type:ignore[list-item]
429429
]
430430
previous_qpoint = self.qpoints[idx]
431431
previous_distance = self.distance[idx]
@@ -452,22 +452,22 @@ def _reuse_init(
452452
for idx in range(self.nb_qpoints):
453453
# get directions with nac irrespectively of the label_dict. NB: with labels
454454
# the gamma point is expected to appear twice consecutively.
455-
if np.allclose(qpoints[idx], (0, 0, 0)):
456-
if idx > 0 and not np.allclose(qpoints[idx - 1], (0, 0, 0)):
455+
if np.allclose(qpoints[idx], (0, 0, 0)): # type:ignore[arg-type]
456+
if idx > 0 and not np.allclose(qpoints[idx - 1], (0, 0, 0)): # type:ignore[arg-type]
457457
q_dir = self.qpoints[idx - 1]
458458
direction = q_dir.frac_coords / np.linalg.norm(q_dir.frac_coords)
459459
naf.append((direction, frequencies[:, idx]))
460460
if self.has_eigendisplacements:
461461
nac_eigendisplacements.append((direction, eigendisplacements[:, idx]))
462-
if idx < len(qpoints) - 1 and not np.allclose(qpoints[idx + 1], (0, 0, 0)):
462+
if idx < len(qpoints) - 1 and not np.allclose(qpoints[idx + 1], (0, 0, 0)): # type:ignore[arg-type]
463463
q_dir = self.qpoints[idx + 1]
464464
direction = q_dir.frac_coords / np.linalg.norm(q_dir.frac_coords)
465465
naf.append((direction, frequencies[:, idx]))
466466
if self.has_eigendisplacements:
467467
nac_eigendisplacements.append((direction, eigendisplacements[:, idx]))
468468

469-
self.nac_frequencies = np.array(naf, dtype=object)
470-
self.nac_eigendisplacements = np.array(nac_eigendisplacements, dtype=object)
469+
self.nac_frequencies = np.array(naf, dtype=object) # type:ignore[assignment]
470+
self.nac_eigendisplacements = np.array(nac_eigendisplacements, dtype=object) # type:ignore[assignment]
471471

472472
def get_equivalent_qpoints(self, index: int) -> list[int]:
473473
"""Get the list of qpoint indices equivalent (meaning they are the
@@ -586,7 +586,7 @@ def as_phononwebsite(self) -> dict:
586586
line_breaks.append((nq_start, nq))
587587
nq_start = nq
588588
else:
589-
dist += np.linalg.norm(q1 - q2)
589+
dist += np.linalg.norm(q1 - q2) # type:ignore[assignment]
590590
distances.append(dist)
591591
line_breaks.append((nq_start, len(qpoints)))
592592
dct["distances"] = distances
@@ -624,8 +624,8 @@ def band_reorder(self) -> None:
624624

625625
# Get order
626626
for nq in range(1, n_qpoints):
627-
old_eig_vecs = eigenvectors_from_displacements(eigen_displacements[:, nq - 1], atomic_masses)
628-
new_eig_vecs = eigenvectors_from_displacements(eigen_displacements[:, nq], atomic_masses)
627+
old_eig_vecs = eigenvectors_from_displacements(eigen_displacements[:, nq - 1], atomic_masses) # type:ignore[arg-type]
628+
new_eig_vecs = eigenvectors_from_displacements(eigen_displacements[:, nq], atomic_masses) # type:ignore[arg-type]
629629
order[nq] = estimate_band_connection(
630630
old_eig_vecs.reshape([n_phonons, n_phonons]).T,
631631
new_eig_vecs.reshape([n_phonons, n_phonons]).T,

src/pymatgen/phonon/dos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pymatgen.util.coord import get_linear_interpolated_value
1717

1818
if version.parse(np.__version__) < version.parse("2.0.0"):
19-
np.trapezoid = np.trapz # noqa: NPY201
19+
np.trapezoid = np.trapz # type:ignore[assignment] # noqa: NPY201
2020

2121
if TYPE_CHECKING:
2222
from numpy.typing import ArrayLike, NDArray
@@ -155,7 +155,7 @@ def as_dict(self) -> dict:
155155
@lazy_property
156156
def ind_zero_freq(self) -> int:
157157
"""Index of the first point for which the frequencies are >= 0."""
158-
ind = np.searchsorted(self.frequencies, 0)
158+
ind = np.searchsorted(self.frequencies, 0).astype(int)
159159
if ind >= len(self.frequencies):
160160
raise ValueError("No positive frequencies found")
161161
return ind

src/pymatgen/phonon/gruneisen.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class GruneisenParameter(MSONable):
4444
def __init__(
4545
self,
4646
qpoints: ArrayLike,
47-
gruneisen: ArrayLike[ArrayLike],
48-
frequencies: ArrayLike[ArrayLike],
47+
gruneisen: ArrayLike,
48+
frequencies: ArrayLike,
4949
multiplicities: Sequence | None = None,
5050
structure: Structure = None,
5151
lattice: Lattice = None,
@@ -248,10 +248,10 @@ class GruneisenPhononBandStructure(PhononBandStructure):
248248
def __init__(
249249
self,
250250
qpoints: ArrayLike,
251-
frequencies: ArrayLike[ArrayLike],
251+
frequencies: ArrayLike,
252252
gruneisenparameters: ArrayLike,
253253
lattice: Lattice,
254-
eigendisplacements: ArrayLike[ArrayLike] = None,
254+
eigendisplacements: ArrayLike = None,
255255
labels_dict: dict | None = None,
256256
coords_are_cartesian: bool = False,
257257
structure: Structure | None = None,
@@ -351,10 +351,10 @@ class GruneisenPhononBandStructureSymmLine(GruneisenPhononBandStructure, PhononB
351351
def __init__(
352352
self,
353353
qpoints: ArrayLike,
354-
frequencies: ArrayLike[ArrayLike],
354+
frequencies: ArrayLike,
355355
gruneisenparameters: ArrayLike,
356356
lattice: Lattice,
357-
eigendisplacements: ArrayLike[ArrayLike] = None,
357+
eigendisplacements: ArrayLike = None,
358358
labels_dict: dict | None = None,
359359
coords_are_cartesian: bool = False,
360360
structure: Structure | None = None,
@@ -397,10 +397,10 @@ def __init__(
397397

398398
PhononBandStructureSymmLine._reuse_init(
399399
self,
400-
eigendisplacements=eigendisplacements,
400+
eigendisplacements=eigendisplacements, # type:ignore[arg-type]
401401
frequencies=frequencies,
402402
has_nac=False,
403-
qpoints=qpoints,
403+
qpoints=qpoints, # type:ignore[arg-type]
404404
)
405405

406406
@classmethod

src/pymatgen/phonon/plotter.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from matplotlib.axes import Axes
2727
from matplotlib.figure import Figure
28+
from numpy.typing import ArrayLike
2829

2930
from pymatgen.core import Structure
3031
from pymatgen.phonon.dos import PhononDos
@@ -219,7 +220,7 @@ def get_plot(
219220
if self.stack:
220221
ax.fill(xs, ys, color=color, label=str(key), **kwargs)
221222
else:
222-
ax.plot(xs, ys, color=color, label=str(key), linewidth=linewidth, **kwargs)
223+
ax.plot(xs, ys, color=color, label=str(key), linewidth=linewidth, **kwargs) # type:ignore[arg-type]
223224

224225
if xlim:
225226
ax.set_xlim(xlim)
@@ -529,15 +530,15 @@ def get_proj_plot(
529530
seg[:, :, 1] = self._bs.bands[:, d - 1 : d + 1] * u.factor
530531
seg[:, 0, 0] = k_dist[d - 1]
531532
seg[:, 1, 0] = k_dist[d]
532-
ls = LineCollection(seg, colors=colors, linestyles="-", linewidths=2.5)
533+
ls = LineCollection(seg, colors=colors, linestyles="-", linewidths=2.5) # type:ignore[arg-type]
533534
ax.add_collection(ls)
534535
if ylim is None:
535536
y_max: float = max(max(band) for band in self._bs.bands) * u.factor
536537
y_min: float = min(min(band) for band in self._bs.bands) * u.factor
537538
y_margin = (y_max - y_min) * 0.05
538539
ylim = (y_min - y_margin, y_max + y_margin)
539-
ax.set_ylim(ylim)
540-
xlim = [min(k_dist), max(k_dist)]
540+
ax.set_ylim(ylim) # type:ignore[arg-type]
541+
xlim = (min(k_dist), max(k_dist))
541542
ax.set_xlim(xlim)
542543
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=28)
543544
ylabel = rf"$\mathrm{{Frequencies\ ({u.label})}}$"
@@ -664,7 +665,7 @@ def plot_compare(
664665
on_incompatible: Literal["raise", "warn", "ignore"] = "raise",
665666
other_kwargs: dict | None = None,
666667
**kwargs,
667-
) -> Axes:
668+
) -> Axes | None:
668669
"""Plot two band structure for comparison. self in blue, others in red, green, ...
669670
The band structures need to be defined on the same symmetry lines!
670671
The distance between symmetry lines is determined by the band structure used to
@@ -722,7 +723,7 @@ def plot_compare(
722723
raise ValueError("The two band structures are not compatible.")
723724
if on_incompatible == "warn":
724725
logger.warning("The two band structures are not compatible.")
725-
return None # ignore/warn
726+
return None
726727

727728
color = colors[idx + 1] if colors else _colors[1 + idx % len(_colors)]
728729
_kwargs = kwargs.copy() # Don't set the color in kwargs, or every band will be red
@@ -794,7 +795,7 @@ def __init__(self, dos: PhononDos, structure: Structure = None) -> None:
794795
def _plot_thermo(
795796
self,
796797
func: Callable[[float, Structure | None], float],
797-
temperatures: Sequence[float],
798+
temperatures: ArrayLike,
798799
factor: float = 1,
799800
ax: Axes = None,
800801
ylabel: str | None = None,
@@ -824,7 +825,7 @@ def _plot_thermo(
824825
values = []
825826

826827
for temp in temperatures:
827-
values.append(func(temp, self.structure) * factor)
828+
values.append(func(temp, self.structure) * factor) # type:ignore[arg-type]
828829

829830
ax.plot(temperatures, values, label=label, **kwargs)
830831

0 commit comments

Comments
 (0)