|
12 | 12 | import plotly.graph_objects as go |
13 | 13 | import scipy.constants as const |
14 | 14 | from plotly.subplots import make_subplots |
15 | | -from pymatgen.phonon.dos import PhononDos |
| 15 | +from pymatgen.phonon.dos import CompletePhononDos, PhononDos |
16 | 16 |
|
17 | 17 | from pymatviz.phonons.helpers import ( |
18 | 18 | AnyBandStructure, |
@@ -361,90 +361,168 @@ def phonon_bands( |
361 | 361 |
|
362 | 362 |
|
363 | 363 | def phonon_dos( |
364 | | - doses: AnyDos | Mapping[str, AnyDos], |
| 364 | + doses: AnyDos | CompletePhononDos | Mapping[str, AnyDos | CompletePhononDos], |
365 | 365 | *, |
366 | 366 | stack: bool = False, |
367 | 367 | sigma: float = 0, |
368 | 368 | units: Literal["THz", "eV", "meV", "Ha", "cm-1"] = "THz", |
369 | 369 | normalize: Literal["max", "sum", "integral"] | None = None, |
370 | 370 | last_peak_anno: str | None = None, |
| 371 | + project: Literal["element", "site"] | None = None, |
| 372 | + show_total: bool = True, |
371 | 373 | **kwargs: Any, |
372 | 374 | ) -> go.Figure: |
373 | 375 | """Plot phonon DOS using Plotly. |
374 | 376 |
|
375 | 377 | Args: |
376 | | - doses (AnyDos | dict[str, AnyDos]): pymatgen |
377 | | - PhononDos or phonopy TotalDos or dict of multiple of either. |
| 378 | + doses (AnyDos | CompletePhononDos | dict[str, AnyDos | CompletePhononDos]): |
| 379 | + pymatgen PhononDos, CompletePhononDos, phonopy TotalDos, or dict of these. |
378 | 380 | stack (bool): Whether to plot the DOS as a stacked area graph. Defaults to |
379 | 381 | False. |
380 | | - sigma (float): Standard deviation for Gaussian smearing. Defaults to None. |
| 382 | + sigma (float): Standard deviation for Gaussian smearing. Defaults to 0. |
381 | 383 | units (str): Units for the frequencies. Defaults to "THz". |
382 | | - legend (dict): Legend configuration. |
383 | | - normalize (bool): Whether to normalize the DOS. Defaults to False. |
| 384 | + normalize (str | None): Normalization mode. One of "max", "sum", "integral", |
| 385 | + or None. Defaults to None. |
384 | 386 | last_peak_anno (str): Annotation for last DOS peak with f-string placeholders |
385 | 387 | for key (of dict containing multiple DOSes), last_peak frequency and units. |
386 | 388 | Defaults to None, meaning last peak annotation is disabled. Set to "" to |
387 | 389 | enable with a sensible default string. |
| 390 | + project (str | None): Projection mode for CompletePhononDos. |
| 391 | + "element" decomposes into per-element partial DOS, "site" into per-site |
| 392 | + partial DOS. Requires CompletePhononDos input. Defaults to None (plot |
| 393 | + total DOS only). |
| 394 | + show_total (bool): When projecting, overlay the total DOS as a dashed gray line. |
| 395 | + Only used when project is not None. Defaults to True. |
388 | 396 | **kwargs: Passed to Plotly's Figure.add_scatter method. |
389 | 397 |
|
390 | 398 | Returns: |
391 | 399 | go.Figure: Plotly figure object. |
| 400 | +
|
| 401 | + Raises: |
| 402 | + TypeError: If project is set but input is not CompletePhononDos. |
392 | 403 | """ |
393 | 404 | valid_normalize = (None, "max", "sum", "integral") |
394 | 405 | if normalize not in valid_normalize: |
395 | 406 | raise ValueError(f"Invalid {normalize=}, must be one of {valid_normalize}.") |
| 407 | + if project not in (None, "element", "site"): |
| 408 | + raise ValueError(f"Invalid {project=}, must be 'element' or 'site'") |
| 409 | + raw_doses = ( |
| 410 | + cast("Mapping[str, AnyDos | CompletePhononDos]", doses) |
| 411 | + if isinstance(doses, Mapping) |
| 412 | + else {"": doses} |
| 413 | + ) |
396 | 414 |
|
397 | | - input_doses = doses if isinstance(doses, Mapping) else {"": doses} |
398 | 415 | dos_dict: dict[str, PhononDos] = {} |
399 | | - for key, dos in input_doses.items(): |
400 | | - cls_name = f"{type(dos).__module__}.{type(dos).__qualname__}" |
401 | | - if cls_name == "phonopy.phonon.dos.TotalDos": |
402 | | - # Cast to Any to access phonopy TotalDos attributes |
403 | | - phonopy_dos = cast("Any", dos) |
404 | | - dos_dict[key] = PhononDos( # type: ignore[index] |
405 | | - frequencies=phonopy_dos.frequency_points, |
406 | | - densities=phonopy_dos.dos, |
407 | | - ) |
408 | | - elif isinstance(dos, PhononDos): |
409 | | - dos_dict[key] = dos # type: ignore[index] |
410 | | - else: |
| 416 | + total_overlay_dict: dict[str, PhononDos] = {} |
| 417 | + for label, raw_dos in raw_doses.items(): |
| 418 | + label_prefix = f"{label} - " if label else "" |
| 419 | + if project is None: |
| 420 | + cls_name = f"{type(raw_dos).__module__}.{type(raw_dos).__qualname__}" |
| 421 | + if cls_name == "phonopy.phonon.dos.TotalDos": |
| 422 | + phonopy_total_dos = cast("Any", raw_dos) |
| 423 | + dos_dict[label] = PhononDos( |
| 424 | + frequencies=phonopy_total_dos.frequency_points, |
| 425 | + densities=phonopy_total_dos.dos, |
| 426 | + ) |
| 427 | + elif isinstance(raw_dos, CompletePhononDos): |
| 428 | + dos_dict[label] = PhononDos(raw_dos.frequencies, raw_dos.densities) |
| 429 | + elif isinstance(raw_dos, PhononDos): |
| 430 | + dos_dict[label] = raw_dos |
| 431 | + else: |
| 432 | + raise TypeError( |
| 433 | + f"Only {PhononDos.__name__}, {CompletePhononDos.__name__}, " |
| 434 | + "phonopy TotalDos, or dict of these supported, " |
| 435 | + f"got {type(raw_dos).__name__}" |
| 436 | + ) |
| 437 | + continue |
| 438 | + if not isinstance(raw_dos, CompletePhononDos): |
411 | 439 | raise TypeError( |
412 | | - f"Only {PhononDos.__name__} or dict supported, got {type(dos).__name__}" |
| 440 | + f"project={project!r} requires CompletePhononDos, " |
| 441 | + f"got {type(raw_dos).__name__} for key {label!r}" |
413 | 442 | ) |
414 | | - if len(dos_dict) == 0: |
| 443 | + projected_dos = ( |
| 444 | + raw_dos.get_element_dos() |
| 445 | + if project == "element" |
| 446 | + else { |
| 447 | + f"{site.specie}{site_idx}": raw_dos.get_site_dos(site) |
| 448 | + for site_idx, site in enumerate(raw_dos.structure) |
| 449 | + } |
| 450 | + ) |
| 451 | + dos_dict |= {f"{label_prefix}{key}": dos for key, dos in projected_dos.items()} |
| 452 | + if show_total: |
| 453 | + total_overlay_dict[f"{label_prefix}Total"] = PhononDos( |
| 454 | + raw_dos.frequencies, raw_dos.densities |
| 455 | + ) |
| 456 | + |
| 457 | + if not dos_dict: |
415 | 458 | raise ValueError("Empty DOS dict") |
416 | 459 |
|
417 | 460 | if last_peak_anno == "": |
418 | 461 | last_peak_anno = "ω<sub>{key}</sub></span>={last_peak:.1f} {units}" |
419 | 462 |
|
420 | | - fig = go.Figure() |
421 | | - |
422 | | - for key, dos in dos_dict.items(): |
423 | | - frequencies = dos.frequencies |
| 463 | + def _prepare_dos(dos: PhononDos) -> tuple[np.ndarray, np.ndarray]: |
| 464 | + """Convert frequencies and apply smearing + normalization.""" |
| 465 | + frequencies = convert_frequencies(dos.frequencies, units) |
424 | 466 | densities = dos.get_smeared_densities(sigma) |
425 | | - |
426 | | - # convert frequencies to specified units |
427 | | - frequencies = convert_frequencies(frequencies, units) |
428 | | - |
429 | | - # normalize DOS |
430 | | - if normalize == "max": |
431 | | - densities /= densities.max() |
432 | | - elif normalize == "sum": |
433 | | - densities /= densities.sum() |
| 467 | + if normalize in ("max", "sum"): |
| 468 | + density_norm = densities.max() if normalize == "max" else densities.sum() |
| 469 | + if density_norm == 0: |
| 470 | + msg_key = "max density" if normalize == "max" else "sum density" |
| 471 | + raise ValueError( |
| 472 | + f"Cannot normalize DOS with mode={normalize!r}: {msg_key} is 0." |
| 473 | + ) |
| 474 | + densities = densities / density_norm |
434 | 475 | elif normalize == "integral": |
| 476 | + if len(frequencies) < 2: |
| 477 | + raise ValueError( |
| 478 | + "Cannot normalize DOS with mode='integral': " |
| 479 | + "need >=2 frequency points." |
| 480 | + ) |
435 | 481 | bin_width = frequencies[1] - frequencies[0] |
436 | | - densities = densities / densities.sum() / bin_width |
| 482 | + if bin_width == 0: |
| 483 | + raise ValueError( |
| 484 | + "Cannot normalize DOS with mode='integral': bin width is 0." |
| 485 | + ) |
| 486 | + density_norm = densities.sum() |
| 487 | + if density_norm == 0: |
| 488 | + raise ValueError( |
| 489 | + "Cannot normalize DOS with mode='integral': sum density is 0." |
| 490 | + ) |
| 491 | + densities = densities / density_norm / bin_width |
| 492 | + return frequencies, densities |
437 | 493 |
|
438 | | - scatter_defaults = dict(mode="lines") |
| 494 | + fig = go.Figure() |
| 495 | + cumulative_density_by_group: dict[str, np.ndarray] = {} |
| 496 | + for dos_name, dos_obj in dos_dict.items(): |
| 497 | + frequencies, densities = _prepare_dos(dos_obj) |
| 498 | + scatter_kwargs: dict[str, Any] = {"mode": "lines"} |
439 | 499 | if stack: |
440 | | - if fig.data: # for stacked plots, accumulate densities |
441 | | - densities += fig.data[-1].y |
442 | | - scatter_defaults.setdefault("fill", "tonexty") |
443 | | - |
| 500 | + stack_group = ( |
| 501 | + "" |
| 502 | + if project is None or " - " not in dos_name |
| 503 | + else dos_name.split(" - ", maxsplit=1)[0] |
| 504 | + ) |
| 505 | + densities = densities + cumulative_density_by_group.get( |
| 506 | + stack_group, np.zeros_like(densities) |
| 507 | + ) |
| 508 | + cumulative_density_by_group[stack_group] = densities |
| 509 | + scatter_kwargs["fill"] = "tonexty" |
444 | 510 | fig.add_scatter( |
445 | | - x=frequencies, y=densities, name=key, **scatter_defaults | kwargs |
| 511 | + x=frequencies, y=densities, name=dos_name, **scatter_kwargs | kwargs |
446 | 512 | ) |
447 | 513 |
|
| 514 | + if project is not None and show_total: |
| 515 | + for total_name, total_dos in total_overlay_dict.items(): |
| 516 | + frequencies, densities = _prepare_dos(total_dos) |
| 517 | + fig.add_scatter( |
| 518 | + x=frequencies, |
| 519 | + y=densities, |
| 520 | + name=total_name, |
| 521 | + mode="lines", |
| 522 | + line=dict(dash="dash", color="gray", width=1.5), |
| 523 | + showlegend=True, |
| 524 | + ) |
| 525 | + |
448 | 526 | fig.layout.xaxis.update(title=f"Frequency ({units})") |
449 | 527 | fig.layout.yaxis.update(title="Density of States", rangemode="tozero") |
450 | 528 | fig.layout.margin = dict(t=5, b=5, l=5, r=5) |
|
0 commit comments