Skip to content

Commit 4bb4653

Browse files
authored
Add element and site projected phonon DOS support (#339)
* Add element and site projected phonon DOS support - extend phonon_dos() with CompletePhononDos projection modes and optional total overlays - harden DOS normalization with explicit error handling for zero-density and invalid integral inputs - expand phonon plotly coverage with projected DOS behavior and normalization edge-case tests - update phonon fixtures, example scripts, and README links for the new projected DOS workflow * Refine phonon DOS projection handling - simplify phonon_dos normalization and projection control flow - support CompletePhononDos total plotting when project is None - add regression test and refresh README source links from hooks * Simplify phonon DOS projection flow - streamline phonon_dos projection handling with one validated branch - reduce repetition in phonon DOS example plotting via looped scenarios - preserve behavior while keeping tests and type checks green * Simplify stacked DOS density accumulation - Flatten nested ternary into single conditional - Use np.zeros_like fallback instead of None check
1 parent f749c72 commit 4bb4653

20 files changed

+421
-91
lines changed

assets/scripts/phonons/phonon_bands.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@
1414
from pymatviz.utils.testing import TEST_FILES
1515

1616

17-
try:
18-
import atomate2 # noqa: F401
19-
except ImportError:
20-
raise SystemExit(0) from None # need atomate2 for MontyDecoder to load PhononDBDoc
21-
22-
2317
pmv.set_plotly_template("pymatviz_white")
2418

2519

@@ -39,7 +33,7 @@
3933
if "m3gnet" in path
4034
else "PBE"
4135
)
42-
with zopen(path) as file:
36+
with zopen(path, mode="rt") as file:
4337
docs[model_label] = json.loads(file.read(), cls=MontyDecoder)
4438

4539
ph_bands: dict[str, PhononBands] = {

assets/scripts/phonons/phonon_bands_and_dos.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@
1414
from pymatviz.utils.testing import TEST_FILES
1515

1616

17-
try:
18-
import atomate2 # noqa: F401
19-
except ImportError:
20-
raise SystemExit(0) from None # need atomate2 for MontyDecoder to load PhononDBDoc
21-
22-
2317
# %% Plot phonon bands and DOS
2418
for mp_id, formula in (
2519
("mp-2758", "Sr4Se4"),
@@ -28,7 +22,7 @@
2822
docs: dict[str, PhononDBDoc] = {}
2923
for path in glob(f"{TEST_FILES}/phonons/{mp_id}-{formula}-*.json.xz"):
3024
key = path.split("-")[-1].split(".")[0]
31-
with zopen(path) as file:
25+
with zopen(path, mode="rt") as file:
3226
docs[key] = json.loads(file.read(), cls=MontyDecoder)
3327

3428
ph_bands: dict[str, PhononBands] = {

assets/scripts/phonons/phonon_dos.py

Lines changed: 77 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,39 +3,48 @@
33
# %%
44
import json
55
from glob import glob
6+
from typing import Any, cast
67

8+
import plotly.graph_objects as go
79
from monty.io import zopen
810
from monty.json import MontyDecoder
9-
from pymatgen.phonon.dos import PhononDos
11+
from pymatgen.io.phonopy import get_pmg_structure
12+
from pymatgen.phonon.dos import CompletePhononDos, PhononDos
1013

1114
import pymatviz as pmv
1215
from pymatviz.phonons.helpers import PhononDBDoc
1316
from pymatviz.utils.testing import TEST_FILES, load_phonopy_nacl
1417

1518

16-
try:
17-
import atomate2 # noqa: F401
18-
except ImportError:
19-
raise SystemExit(0) from None # need atomate2 for MontyDecoder to load PhononDBDoc
19+
pmv.set_plotly_template("pymatviz_white")
20+
21+
22+
class PhonopyDosMissingError(RuntimeError):
23+
"""Raised when phonopy fails to compute a required DOS output."""
2024

2125

22-
# %% Plot phonon bands and DOS
26+
def show_figure(plotly_figure: go.Figure, title: str, *, y_pos: float = 0.97) -> None:
27+
"""Apply consistent layout settings and display the figure."""
28+
plotly_figure.layout.title = dict(text=title, x=0.5, y=y_pos)
29+
plotly_figure.layout.margin = dict(l=0, r=0, b=0, t=40)
30+
plotly_figure.show()
31+
32+
33+
# %% Plot phonon DOS (total)
2334
for mp_id, formula in (
2435
("mp-2758", "Sr4Se4"),
2536
("mp-23907", "H2"),
2637
):
2738
docs: dict[str, PhononDBDoc] = {}
2839
for path in glob(f"{TEST_FILES}/phonons/{mp_id}-{formula}-*.json.xz"):
2940
key = path.split("-", maxsplit=3)[-1].split(".")[0]
30-
with zopen(path) as file:
41+
with zopen(path, mode="rt") as file:
3142
docs[key] = json.loads(file.read(), cls=MontyDecoder)
3243

3344
ph_doses: dict[str, PhononDos] = {key: doc.phonon_dos for key, doc in docs.items()}
3445

3546
fig = pmv.phonon_dos(ph_doses)
36-
fig.layout.title = dict(text=f"Phonon DOS of {formula} ({mp_id})", x=0.5, y=0.98)
37-
fig.layout.margin = dict(l=0, r=0, b=0, t=40)
38-
fig.show()
47+
show_figure(fig, f"Phonon DOS of {formula} ({mp_id})", y_pos=0.98)
3948
# pmv.io.save_and_compress_svg(fig, f"phonon-dos-{mp_id}")
4049

4150

@@ -49,11 +58,66 @@
4958
phonopy_nacl = load_phonopy_nacl()
5059
phonopy_nacl.run_mesh([10, 10, 10])
5160
phonopy_nacl.run_total_dos()
61+
if phonopy_nacl.total_dos is None:
62+
raise PhonopyDosMissingError
5263

5364
plt = phonopy_nacl.plot_total_dos()
5465
plt.title("NaCl DOS plotted by phonopy")
5566

5667
fig = pmv.phonon_dos(phonopy_nacl.total_dos)
57-
fig.layout.title = dict(text="NaCl DOS plotted by pymatviz", x=0.5, y=0.97)
58-
fig.layout.margin = dict(l=0, r=0, b=0, t=40)
59-
fig.show()
68+
show_figure(fig, "NaCl DOS plotted by pymatviz")
69+
70+
71+
# %% Element-projected phonon DOS from phonopy
72+
# Build a CompletePhononDos from phonopy's projected DOS
73+
phonopy_nacl_pdos = load_phonopy_nacl()
74+
phonopy_nacl_pdos.run_mesh([10, 10, 10], with_eigenvectors=True, is_mesh_symmetry=False)
75+
phonopy_nacl_pdos.run_projected_dos()
76+
phonopy_nacl_pdos.run_total_dos()
77+
if phonopy_nacl_pdos.total_dos is None:
78+
raise PhonopyDosMissingError
79+
if phonopy_nacl_pdos.projected_dos is None:
80+
raise PhonopyDosMissingError
81+
82+
struct = get_pmg_structure(phonopy_nacl_pdos.primitive)
83+
total_dos = PhononDos(
84+
phonopy_nacl_pdos.total_dos.frequency_points,
85+
phonopy_nacl_pdos.total_dos.dos,
86+
)
87+
site_dos = {
88+
site: phonopy_nacl_pdos.projected_dos.projected_dos[idx]
89+
for idx, site in enumerate(struct)
90+
}
91+
complete_dos = CompletePhononDos(struct, total_dos, site_dos)
92+
93+
94+
# %% Element-projected DOS (default: with total overlay)
95+
projected_examples: list[tuple[str, dict[str, str | bool]]] = [
96+
("NaCl Element-Projected Phonon DOS", {"project": "element"}),
97+
(
98+
"NaCl Element-Projected Phonon DOS (no total)",
99+
{"project": "element", "show_total": False},
100+
),
101+
(
102+
"NaCl Element-Projected Phonon DOS (stacked)",
103+
{"project": "element", "stack": True, "show_total": False},
104+
),
105+
("NaCl Site-Projected Phonon DOS", {"project": "site"}),
106+
(
107+
"NaCl Element-Projected Phonon DOS (normalized)",
108+
{"project": "element", "normalize": "max"},
109+
),
110+
]
111+
for plot_title, plot_kwargs in projected_examples:
112+
fig = pmv.phonon_dos(complete_dos, **cast("dict[str, Any]", plot_kwargs))
113+
show_figure(fig, plot_title)
114+
115+
# pmv.io.save_and_compress_svg(fig, "phonon-dos-element-projected")
116+
# pmv.io.save_and_compress_svg(fig, "phonon-dos-site-projected")
117+
118+
119+
# %% Comparing multiple models with element projection
120+
dos_dict = {"model A": complete_dos, "model B": complete_dos}
121+
fig = pmv.phonon_dos(dos_dict, project="element")
122+
show_figure(fig, "NaCl Multi-Model Element-Projected DOS")
123+
# pmv.io.save_and_compress_svg(fig, "phonon-dos-multi-model-element-projected")

pymatviz/phonons/figures.py

Lines changed: 120 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import plotly.graph_objects as go
1313
import scipy.constants as const
1414
from plotly.subplots import make_subplots
15-
from pymatgen.phonon.dos import PhononDos
15+
from pymatgen.phonon.dos import CompletePhononDos, PhononDos
1616

1717
from pymatviz.phonons.helpers import (
1818
AnyBandStructure,
@@ -361,90 +361,168 @@ def phonon_bands(
361361

362362

363363
def phonon_dos(
364-
doses: AnyDos | Mapping[str, AnyDos],
364+
doses: AnyDos | CompletePhononDos | Mapping[str, AnyDos | CompletePhononDos],
365365
*,
366366
stack: bool = False,
367367
sigma: float = 0,
368368
units: Literal["THz", "eV", "meV", "Ha", "cm-1"] = "THz",
369369
normalize: Literal["max", "sum", "integral"] | None = None,
370370
last_peak_anno: str | None = None,
371+
project: Literal["element", "site"] | None = None,
372+
show_total: bool = True,
371373
**kwargs: Any,
372374
) -> go.Figure:
373375
"""Plot phonon DOS using Plotly.
374376
375377
Args:
376-
doses (AnyDos | dict[str, AnyDos]): pymatgen
377-
PhononDos or phonopy TotalDos or dict of multiple of either.
378+
doses (AnyDos | CompletePhononDos | dict[str, AnyDos | CompletePhononDos]):
379+
pymatgen PhononDos, CompletePhononDos, phonopy TotalDos, or dict of these.
378380
stack (bool): Whether to plot the DOS as a stacked area graph. Defaults to
379381
False.
380-
sigma (float): Standard deviation for Gaussian smearing. Defaults to None.
382+
sigma (float): Standard deviation for Gaussian smearing. Defaults to 0.
381383
units (str): Units for the frequencies. Defaults to "THz".
382-
legend (dict): Legend configuration.
383-
normalize (bool): Whether to normalize the DOS. Defaults to False.
384+
normalize (str | None): Normalization mode. One of "max", "sum", "integral",
385+
or None. Defaults to None.
384386
last_peak_anno (str): Annotation for last DOS peak with f-string placeholders
385387
for key (of dict containing multiple DOSes), last_peak frequency and units.
386388
Defaults to None, meaning last peak annotation is disabled. Set to "" to
387389
enable with a sensible default string.
390+
project (str | None): Projection mode for CompletePhononDos.
391+
"element" decomposes into per-element partial DOS, "site" into per-site
392+
partial DOS. Requires CompletePhononDos input. Defaults to None (plot
393+
total DOS only).
394+
show_total (bool): When projecting, overlay the total DOS as a dashed gray line.
395+
Only used when project is not None. Defaults to True.
388396
**kwargs: Passed to Plotly's Figure.add_scatter method.
389397
390398
Returns:
391399
go.Figure: Plotly figure object.
400+
401+
Raises:
402+
TypeError: If project is set but input is not CompletePhononDos.
392403
"""
393404
valid_normalize = (None, "max", "sum", "integral")
394405
if normalize not in valid_normalize:
395406
raise ValueError(f"Invalid {normalize=}, must be one of {valid_normalize}.")
407+
if project not in (None, "element", "site"):
408+
raise ValueError(f"Invalid {project=}, must be 'element' or 'site'")
409+
raw_doses = (
410+
cast("Mapping[str, AnyDos | CompletePhononDos]", doses)
411+
if isinstance(doses, Mapping)
412+
else {"": doses}
413+
)
396414

397-
input_doses = doses if isinstance(doses, Mapping) else {"": doses}
398415
dos_dict: dict[str, PhononDos] = {}
399-
for key, dos in input_doses.items():
400-
cls_name = f"{type(dos).__module__}.{type(dos).__qualname__}"
401-
if cls_name == "phonopy.phonon.dos.TotalDos":
402-
# Cast to Any to access phonopy TotalDos attributes
403-
phonopy_dos = cast("Any", dos)
404-
dos_dict[key] = PhononDos( # type: ignore[index]
405-
frequencies=phonopy_dos.frequency_points,
406-
densities=phonopy_dos.dos,
407-
)
408-
elif isinstance(dos, PhononDos):
409-
dos_dict[key] = dos # type: ignore[index]
410-
else:
416+
total_overlay_dict: dict[str, PhononDos] = {}
417+
for label, raw_dos in raw_doses.items():
418+
label_prefix = f"{label} - " if label else ""
419+
if project is None:
420+
cls_name = f"{type(raw_dos).__module__}.{type(raw_dos).__qualname__}"
421+
if cls_name == "phonopy.phonon.dos.TotalDos":
422+
phonopy_total_dos = cast("Any", raw_dos)
423+
dos_dict[label] = PhononDos(
424+
frequencies=phonopy_total_dos.frequency_points,
425+
densities=phonopy_total_dos.dos,
426+
)
427+
elif isinstance(raw_dos, CompletePhononDos):
428+
dos_dict[label] = PhononDos(raw_dos.frequencies, raw_dos.densities)
429+
elif isinstance(raw_dos, PhononDos):
430+
dos_dict[label] = raw_dos
431+
else:
432+
raise TypeError(
433+
f"Only {PhononDos.__name__}, {CompletePhononDos.__name__}, "
434+
"phonopy TotalDos, or dict of these supported, "
435+
f"got {type(raw_dos).__name__}"
436+
)
437+
continue
438+
if not isinstance(raw_dos, CompletePhononDos):
411439
raise TypeError(
412-
f"Only {PhononDos.__name__} or dict supported, got {type(dos).__name__}"
440+
f"project={project!r} requires CompletePhononDos, "
441+
f"got {type(raw_dos).__name__} for key {label!r}"
413442
)
414-
if len(dos_dict) == 0:
443+
projected_dos = (
444+
raw_dos.get_element_dos()
445+
if project == "element"
446+
else {
447+
f"{site.specie}{site_idx}": raw_dos.get_site_dos(site)
448+
for site_idx, site in enumerate(raw_dos.structure)
449+
}
450+
)
451+
dos_dict |= {f"{label_prefix}{key}": dos for key, dos in projected_dos.items()}
452+
if show_total:
453+
total_overlay_dict[f"{label_prefix}Total"] = PhononDos(
454+
raw_dos.frequencies, raw_dos.densities
455+
)
456+
457+
if not dos_dict:
415458
raise ValueError("Empty DOS dict")
416459

417460
if last_peak_anno == "":
418461
last_peak_anno = "ω<sub>{key}</sub></span>={last_peak:.1f} {units}"
419462

420-
fig = go.Figure()
421-
422-
for key, dos in dos_dict.items():
423-
frequencies = dos.frequencies
463+
def _prepare_dos(dos: PhononDos) -> tuple[np.ndarray, np.ndarray]:
464+
"""Convert frequencies and apply smearing + normalization."""
465+
frequencies = convert_frequencies(dos.frequencies, units)
424466
densities = dos.get_smeared_densities(sigma)
425-
426-
# convert frequencies to specified units
427-
frequencies = convert_frequencies(frequencies, units)
428-
429-
# normalize DOS
430-
if normalize == "max":
431-
densities /= densities.max()
432-
elif normalize == "sum":
433-
densities /= densities.sum()
467+
if normalize in ("max", "sum"):
468+
density_norm = densities.max() if normalize == "max" else densities.sum()
469+
if density_norm == 0:
470+
msg_key = "max density" if normalize == "max" else "sum density"
471+
raise ValueError(
472+
f"Cannot normalize DOS with mode={normalize!r}: {msg_key} is 0."
473+
)
474+
densities = densities / density_norm
434475
elif normalize == "integral":
476+
if len(frequencies) < 2:
477+
raise ValueError(
478+
"Cannot normalize DOS with mode='integral': "
479+
"need >=2 frequency points."
480+
)
435481
bin_width = frequencies[1] - frequencies[0]
436-
densities = densities / densities.sum() / bin_width
482+
if bin_width == 0:
483+
raise ValueError(
484+
"Cannot normalize DOS with mode='integral': bin width is 0."
485+
)
486+
density_norm = densities.sum()
487+
if density_norm == 0:
488+
raise ValueError(
489+
"Cannot normalize DOS with mode='integral': sum density is 0."
490+
)
491+
densities = densities / density_norm / bin_width
492+
return frequencies, densities
437493

438-
scatter_defaults = dict(mode="lines")
494+
fig = go.Figure()
495+
cumulative_density_by_group: dict[str, np.ndarray] = {}
496+
for dos_name, dos_obj in dos_dict.items():
497+
frequencies, densities = _prepare_dos(dos_obj)
498+
scatter_kwargs: dict[str, Any] = {"mode": "lines"}
439499
if stack:
440-
if fig.data: # for stacked plots, accumulate densities
441-
densities += fig.data[-1].y
442-
scatter_defaults.setdefault("fill", "tonexty")
443-
500+
stack_group = (
501+
""
502+
if project is None or " - " not in dos_name
503+
else dos_name.split(" - ", maxsplit=1)[0]
504+
)
505+
densities = densities + cumulative_density_by_group.get(
506+
stack_group, np.zeros_like(densities)
507+
)
508+
cumulative_density_by_group[stack_group] = densities
509+
scatter_kwargs["fill"] = "tonexty"
444510
fig.add_scatter(
445-
x=frequencies, y=densities, name=key, **scatter_defaults | kwargs
511+
x=frequencies, y=densities, name=dos_name, **scatter_kwargs | kwargs
446512
)
447513

514+
if project is not None and show_total:
515+
for total_name, total_dos in total_overlay_dict.items():
516+
frequencies, densities = _prepare_dos(total_dos)
517+
fig.add_scatter(
518+
x=frequencies,
519+
y=densities,
520+
name=total_name,
521+
mode="lines",
522+
line=dict(dash="dash", color="gray", width=1.5),
523+
showlegend=True,
524+
)
525+
448526
fig.layout.xaxis.update(title=f"Frequency ({units})")
449527
fig.layout.yaxis.update(title="Density of States", rangemode="tozero")
450528
fig.layout.margin = dict(t=5, b=5, l=5, r=5)

0 commit comments

Comments
 (0)