|
4 | 4 |
|
5 | 5 | import sys |
6 | 6 | from collections import defaultdict |
7 | | -from collections.abc import Mapping |
| 7 | +from collections.abc import Callable, Mapping |
8 | 8 | from typing import TYPE_CHECKING, Any, Literal, cast |
9 | 9 |
|
10 | 10 | import numpy as np |
@@ -271,7 +271,10 @@ def phonon_bands( |
271 | 271 | # Apply line style based on line_kwargs type |
272 | 272 | if callable(line_kwargs): |
273 | 273 | # Pass band data and index to callback |
274 | | - custom_style = line_kwargs(frequencies, band_idx) |
| 274 | + line_style_callback = cast( |
| 275 | + "Callable[[np.ndarray, int], dict[str, Any]]", line_kwargs |
| 276 | + ) |
| 277 | + custom_style = line_style_callback(frequencies, band_idx) |
275 | 278 | line_defaults |= custom_style |
276 | 279 | elif isinstance(line_kwargs, dict): |
277 | 280 | # check for custom line styles for one or both modes |
@@ -413,6 +416,9 @@ def phonon_dos( |
413 | 416 | ) |
414 | 417 |
|
415 | 418 | dos_dict: dict[str, PhononDos] = {} |
| 419 | + stack_group_by_trace: dict[str, str] | None = ( |
| 420 | + {} if stack and project is not None else None |
| 421 | + ) |
416 | 422 | total_overlay_dict: dict[str, PhononDos] = {} |
417 | 423 | for label, raw_dos in raw_doses.items(): |
418 | 424 | label_prefix = f"{label} - " if label else "" |
@@ -448,7 +454,11 @@ def phonon_dos( |
448 | 454 | for site_idx, site in enumerate(raw_dos.structure) |
449 | 455 | } |
450 | 456 | ) |
451 | | - dos_dict |= {f"{label_prefix}{key}": dos for key, dos in projected_dos.items()} |
| 457 | + for key, dos in projected_dos.items(): |
| 458 | + trace_name = f"{label_prefix}{key}" |
| 459 | + dos_dict[trace_name] = dos |
| 460 | + if stack_group_by_trace is not None: |
| 461 | + stack_group_by_trace[trace_name] = label |
452 | 462 | if show_total: |
453 | 463 | total_overlay_dict[f"{label_prefix}Total"] = PhononDos( |
454 | 464 | raw_dos.frequencies, raw_dos.densities |
@@ -493,20 +503,27 @@ def _prepare_dos(dos: PhononDos) -> tuple[np.ndarray, np.ndarray]: |
493 | 503 |
|
494 | 504 | fig = go.Figure() |
495 | 505 | cumulative_density_by_group: dict[str, np.ndarray] = {} |
| 506 | + seen_stack_groups: set[str] = set() |
| 507 | + |
| 508 | + def _stack_group(trace_name: str) -> str: |
| 509 | + """Return stack accumulation group for this DOS trace.""" |
| 510 | + if project is None: |
| 511 | + return "" |
| 512 | + return stack_group_by_trace.get(trace_name, "") if stack_group_by_trace else "" |
| 513 | + |
496 | 514 | for dos_name, dos_obj in dos_dict.items(): |
497 | 515 | frequencies, densities = _prepare_dos(dos_obj) |
498 | 516 | scatter_kwargs: dict[str, Any] = {"mode": "lines"} |
499 | 517 | if stack: |
500 | | - stack_group = ( |
501 | | - "" |
502 | | - if project is None or " - " not in dos_name |
503 | | - else dos_name.split(" - ", maxsplit=1)[0] |
504 | | - ) |
| 518 | + stack_group = _stack_group(dos_name) |
505 | 519 | densities = densities + cumulative_density_by_group.get( |
506 | 520 | stack_group, np.zeros_like(densities) |
507 | 521 | ) |
508 | 522 | cumulative_density_by_group[stack_group] = densities |
509 | | - scatter_kwargs["fill"] = "tonexty" |
| 523 | + scatter_kwargs["fill"] = ( |
| 524 | + "tozeroy" if stack_group not in seen_stack_groups else "tonexty" |
| 525 | + ) |
| 526 | + seen_stack_groups.add(stack_group) |
510 | 527 | fig.add_scatter( |
511 | 528 | x=frequencies, y=densities, name=dos_name, **scatter_kwargs | kwargs |
512 | 529 | ) |
|
0 commit comments