Skip to content

Commit b45666b

Browse files
authored
Add PhononDos.mae() and PhononBandStructure.has_imaginary_gamma_freq() methods (#3520)
* fix typos and improve code readability * clarify has_nac doc str * add PhononBandStructure.has_imaginary_gamma_freq() to check for imaginary modes at the gamma point * improve tests for phBS.has_imaginary_freq and add new ones for has_imaginary_gamma_freq * add PhononDos.mae() method to compare two DOSs * add TestPhononDos.test_mae() * increase default has_imaginary_freq tol from 1e-5 to 1e-3
1 parent 131c455 commit b45666b

File tree

6 files changed

+106
-46
lines changed

6 files changed

+106
-46
lines changed

pymatgen/electronic_structure/bandstructure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ def get_equivalent_kpoints(self, index):
779779
TODO: now it uses the label we might want to use coordinates instead
780780
(in case there was a mislabel)
781781
"""
782-
# if the kpoint has no label it can"t have a repetition along the band
782+
# if the kpoint has no label it can't have a repetition along the band
783783
# structure line object
784784

785785
if self.kpoints[index].label is None:

pymatgen/phonon/bandstructure.py

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ def get_reasonable_repetitions(n_atoms: int) -> tuple[int, int, int]:
2424
according to the number of atoms in the system.
2525
"""
2626
if n_atoms < 4:
27-
return (3, 3, 3)
27+
return 3, 3, 3
2828
if 4 <= n_atoms < 15:
29-
return (2, 2, 2)
29+
return 2, 2, 2
3030
if 15 <= n_atoms < 50:
31-
return (2, 2, 1)
31+
return 2, 2, 1
3232

33-
return (1, 1, 1)
33+
return 1, 1, 1
3434

3535

3636
def eigenvectors_from_displacements(disp, masses) -> np.ndarray:
@@ -137,8 +137,8 @@ def __init__(
137137
self.nb_qpoints = len(self.qpoints)
138138

139139
# normalize directions for nac_frequencies and nac_eigendisplacements
140-
self.nac_frequencies = []
141-
self.nac_eigendisplacements = []
140+
self.nac_frequencies: list[tuple[list[float], np.ndarray]] = []
141+
self.nac_eigendisplacements: list[tuple[list[float], np.ndarray]] = []
142142
if nac_frequencies is not None:
143143
for freq in nac_frequencies:
144144
self.nac_frequencies.append(([idx / np.linalg.norm(freq[0]) for idx in freq[0]], freq[1]))
@@ -152,13 +152,29 @@ def min_freq(self) -> tuple[Kpoint, float]:
152152

153153
return self.qpoints[idx[1]], self.bands[idx]
154154

155-
def has_imaginary_freq(self, tol: float = 1e-5) -> bool:
156-
"""True if imaginary frequencies are present in the BS."""
155+
def has_imaginary_freq(self, tol: float = 1e-3) -> bool:
156+
"""True if imaginary frequencies are present anywhere in the band structure. Always True if
157+
has_imaginary_gamma_freq is True.
158+
159+
Args:
160+
tol: Tolerance for determining if a frequency is imaginary. Defaults to 1e-3.
161+
"""
157162
return self.min_freq()[1] + tol < 0
158163

164+
def has_imaginary_gamma_freq(self, tol: float = 1e-3) -> bool:
165+
"""Checks if there are imaginary modes at the gamma point.
166+
167+
Args:
168+
tol: Tolerance for determining if a frequency is imaginary. Defaults to 1e-3.
169+
"""
170+
gamma_freqs = self.bands[:, 0] # frequencies at the Gamma point
171+
return any(freq < -tol for freq in gamma_freqs)
172+
159173
@property
160174
def has_nac(self) -> bool:
161-
"""True if nac_frequencies are present."""
175+
"""True if nac_frequencies are present (i.e. the band structure has been
176+
calculated taking into account Born-charge-derived non-analytical corrections at Gamma).
177+
"""
162178
return len(self.nac_frequencies) > 0
163179

164180
@property
@@ -177,10 +193,10 @@ def get_nac_frequencies_along_dir(self, direction: Sequence) -> np.ndarray | Non
177193
the frequencies as a numpy array o(3*len(structure), len(qpoints)).
178194
None if not found.
179195
"""
180-
versor = [i / np.linalg.norm(direction) for i in direction]
181-
for d, f in self.nac_frequencies:
182-
if np.allclose(versor, d):
183-
return f
196+
versor = [idx / np.linalg.norm(direction) for idx in direction]
197+
for dist, freq in self.nac_frequencies:
198+
if np.allclose(versor, dist):
199+
return freq
184200

185201
return None
186202

@@ -195,10 +211,10 @@ def get_nac_eigendisplacements_along_dir(self, direction) -> np.ndarray | None:
195211
the eigendisplacements as a numpy array of complex numbers with shape
196212
(3*len(structure), len(structure), 3). None if not found.
197213
"""
198-
versor = [i / np.linalg.norm(direction) for i in direction]
199-
for d, e in self.nac_eigendisplacements:
200-
if np.allclose(versor, d):
201-
return e
214+
versor = [idx / np.linalg.norm(direction) for idx in direction]
215+
for dist, eigen_disp in self.nac_eigendisplacements:
216+
if np.allclose(versor, dist):
217+
return eigen_disp
202218

203219
return None
204220

@@ -426,45 +442,40 @@ def get_equivalent_qpoints(self, index: int) -> list[int]:
426442
TODO: now it uses the label we might want to use coordinates instead
427443
(in case there was a mislabel)
428444
"""
429-
# if the qpoint has no label it can"t have a repetition along the band
445+
# if the qpoint has no label it can't have a repetition along the band
430446
# structure line object
431447

432448
if self.qpoints[index].label is None:
433449
return [index]
434450

435451
list_index_qpoints = []
436-
for i in range(self.nb_qpoints):
437-
if self.qpoints[i].label == self.qpoints[index].label:
438-
list_index_qpoints.append(i)
452+
for idx in range(self.nb_qpoints):
453+
if self.qpoints[idx].label == self.qpoints[index].label:
454+
list_index_qpoints.append(idx)
439455

440456
return list_index_qpoints
441457

442-
def get_branch(self, index: int) -> list[dict]:
443-
r"""Returns in what branch(es) is the qpoint. There can be several
444-
branches.
458+
def get_branch(self, index: int) -> list[dict[str, str | int]]:
459+
r"""Returns in what branch(es) is the qpoint. There can be several branches.
445460
446461
Args:
447-
index: the qpoint index
462+
index (int): the qpoint index
448463
449464
Returns:
450-
A list of dictionaries [{"name","start_index","end_index","index"}]
451-
indicating all branches in which the qpoint is. It takes into
452-
account the fact that one qpoint (e.g., \\Gamma) can be in several
453-
branches
465+
list[dict[str, str | int]]: [{"name","start_index","end_index","index"}]
466+
indicating all branches in which the qpoint is. It takes into
467+
account the fact that one qpoint (e.g., \\Gamma) can be in several
468+
branches
454469
"""
455-
to_return = []
456-
for i in self.get_equivalent_qpoints(index):
457-
for b in self.branches:
458-
if b["start_index"] <= i <= b["end_index"]:
459-
to_return.append(
460-
{
461-
"name": b["name"],
462-
"start_index": b["start_index"],
463-
"end_index": b["end_index"],
464-
"index": i,
465-
}
470+
lst = []
471+
for pt_idx in self.get_equivalent_qpoints(index):
472+
for branch in self.branches:
473+
start_idx, end_idx = branch["start_index"], branch["end_index"]
474+
if start_idx <= pt_idx <= end_idx:
475+
lst.append(
476+
{"name": branch["name"], "start_index": start_idx, "end_index": end_idx, "index": pt_idx}
466477
)
467-
return to_return
478+
return lst
468479

469480
def write_phononwebsite(self, filename: str | PathLike) -> None:
470481
"""Write a json file for the phononwebsite:
@@ -606,13 +617,13 @@ def from_dict(cls, dct: dict) -> PhononBandStructureSymmLine:
606617
eigendisplacements = (
607618
np.array(dct["eigendisplacements"]["real"]) + np.array(dct["eigendisplacements"]["imag"]) * 1j
608619
)
609-
structure = Structure.from_dict(dct["structure"]) if "structure" in dct else None
620+
struct = Structure.from_dict(dct["structure"]) if "structure" in dct else None
610621
return cls(
611622
dct["qpoints"],
612623
np.array(dct["bands"]),
613624
lattice_rec,
614625
dct["has_nac"],
615626
eigendisplacements,
616627
dct["labels_dict"],
617-
structure=structure,
628+
structure=struct,
618629
)

pymatgen/phonon/dos.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,28 @@ def zero_point_energy(self, structure: Structure | None = None) -> float:
329329

330330
return zpe
331331

332+
def mae(self, other: PhononDos, two_sided: bool = True) -> float:
333+
"""Mean absolute error between two DOSs.
334+
335+
Args:
336+
other: Another DOS object.
337+
two_sided: Whether to calculate the two-sided MAE meaning interpolate each DOS to the
338+
other's frequencies and averaging the two MAEs. Defaults to True.
339+
340+
Returns:
341+
float: Mean absolute error.
342+
"""
343+
# Interpolate other.densities to align with self.frequencies
344+
self_interpolated = np.interp(self.frequencies, other.frequencies, other.densities)
345+
self_mae = np.abs(self.densities - self_interpolated).mean()
346+
347+
if two_sided:
348+
other_interpolated = np.interp(other.frequencies, self.frequencies, self.densities)
349+
other_mae = np.abs(other.densities - other_interpolated).mean()
350+
return (self_mae + other_mae) / 2
351+
352+
return self_mae
353+
332354

333355
class CompletePhononDos(PhononDos):
334356
"""This wrapper class defines a total dos, and also provides a list of PDos.

pymatgen/phonon/plotter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def _make_ticks(self, ax: Axes) -> Axes:
285285
ax.set_xticks(ticks_labels[0])
286286
ax.set_xticklabels(ticks_labels[1])
287287

288+
# plot vertical lines at each of the ticks
288289
for idx, label in enumerate(ticks["label"]):
289290
if label is not None:
290291
ax.axvline(ticks["distance"][idx], color="black")

tests/phonon/test_bandstructure.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,27 @@ def test_basic(self):
3838

3939
assert list(self.bs.min_freq()[0].frac_coords) == [0, 0, 0]
4040
assert self.bs.min_freq()[1] == approx(-0.03700895020)
41-
assert self.bs.has_imaginary_freq()
42-
assert not self.bs.has_imaginary_freq(tol=0.5)
4341
assert_allclose(self.bs.asr_breaking(), [-0.0370089502, -0.0370089502, -0.0221388897])
4442

4543
assert self.bs.nb_bands == 6
4644
assert self.bs.nb_qpoints == 204
4745

4846
assert_allclose(self.bs.qpoints[1].frac_coords, [0.01, 0, 0])
4947

48+
def test_has_imaginary_freq(self):
49+
for tol in (0, 0.01, 0.02, 0.03, 0.04, 0.05):
50+
assert self.bs.has_imaginary_freq(tol=tol) == (tol < 0.04)
51+
52+
for tol in (0, 0.01, 0.02, 0.03, 0.04, 0.05):
53+
assert self.bs2.has_imaginary_freq(tol=tol) == (tol < 0.01)
54+
55+
# test Gamma point imaginary frequency detection
56+
for tol in (0, 0.01, 0.02, 0.03, 0.04, 0.05):
57+
assert self.bs.has_imaginary_gamma_freq(tol=tol) == (tol < 0.04)
58+
59+
for tol in (0, 0.01, 0.02, 0.03, 0.04, 0.05):
60+
assert self.bs2.has_imaginary_gamma_freq(tol=tol) == (tol < 0.01)
61+
5062
def test_nac(self):
5163
assert self.bs.has_nac
5264
assert not self.bs2.has_nac

tests/phonon/test_dos.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import re
55

6+
import pytest
67
from pytest import approx
78

89
from pymatgen.core import Element
@@ -85,6 +86,19 @@ def test_eq(self):
8586
assert self.dos != 2 * self.dos
8687
assert 2 * self.dos == self.dos + self.dos
8788

89+
def test_mae(self):
90+
assert self.dos.mae(self.dos) == 0
91+
assert self.dos.mae(self.dos + 1) == 1
92+
assert self.dos.mae(self.dos - 1) == 1
93+
assert self.dos.mae(2 * self.dos) == pytest.approx(0.786546967)
94+
assert (2 * self.dos).mae(self.dos) == pytest.approx(0.786546967)
95+
96+
# test two_sided=False after shifting DOS freqs so MAE requires interpolation
97+
dos2 = PhononDos(self.dos.frequencies + 0.01, self.dos.densities)
98+
assert self.dos.mae(dos2 + 1, two_sided=False) == pytest.approx(0.999999999)
99+
assert self.dos.mae(dos2 - 1, two_sided=False) == pytest.approx(1.00000000000031)
100+
assert self.dos.mae(2 * dos2, two_sided=False) == pytest.approx(0.786546967)
101+
88102

89103
class TestCompletePhononDos(PymatgenTest):
90104
def setUp(self):

0 commit comments

Comments
 (0)