Skip to content

Commit 485ed72

Browse files
committed
Add widget state export helpers and harden projected DOS tests (#340)
* Add widget display and state export helpers
  - add `display()` convenience method to `MatterVizWidget` for notebook usage
  - add `to_dict()` to export stable public widget state fields
  - add parameterized tests to verify exported keys and runtime state updates
* fix ty errors
* Improve widget normalization and projected DOS stacking behavior
  - add strict XRD normalization support for canonical and Ferrox dict schemas with validation and hkls normalization
  - enforce exactly one Fermi surface data source and strengthen widget tests by asserting meaningful normalized state
  - fix projected DOS stacked fill baseline per model group and update metallic-glass example docs/output formatting for optional values
* Refine XRD schema precedence and test readability
  - detect canonical and Ferrox schemas only when all required keys are present, and keep targeted errors for partial schemas
  - add regression coverage for mixed-key XRD dicts that should resolve through complete Ferrox schema
  - simplify projected DOS stack reset test by reusing one located trace per model label
* Harden trajectory dict handling and normalize Ferrox HKLs
  - extract Ferrox HKL normalization into a helper for cleaner XRD input handling
  - validate trajectory dict schemas with clearer type/key errors while preserving legacy structure-frame dicts
  - add and refine trajectory widget regression tests for invalid schemas and restored lifecycle state
* Harden trajectory normalization tests
  - validate and complete trajectory dict inputs more robustly, including per-site coordinate checks and safer species handling
  - simplify widget tests by removing tautological/perf-brittle assertions and replacing them with stronger behavioral checks
* fix ty + update readme citations
1 parent 4bb4653 commit 485ed72

37 files changed

+1031
-303
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ default_install_hook_types: [pre-commit, commit-msg]
22

33
repos:
44
- repo: https://github.com/astral-sh/ruff-pre-commit
5-
rev: v0.15.0
5+
rev: v0.15.2
66
hooks:
77
- id: ruff-check
88
args: [--fix]
@@ -14,11 +14,11 @@ repos:
1414
hooks:
1515
- id: ty
1616
name: ty check
17-
entry: ty check
17+
entry: ty check --ignore unused-type-ignore-comment --ignore unused-ignore-comment
1818
language: python
1919
types: [python]
2020
pass_filenames: false
21-
additional_dependencies: [ty>=0.0.15]
21+
additional_dependencies: [ty==0.0.18]
2222
- id: check-readme-src-links
2323
name: Check README source links
2424
entry: python .github/scripts/check_readme_src_links.py
@@ -51,10 +51,10 @@ repos:
5151
- id: codespell
5252
stages: [pre-commit, commit-msg]
5353
exclude_types: [csv, svg, html, yaml, jupyter]
54-
args: [--ignore-words-list, 'hist,mape,te,nd,fpr', --check-filenames]
54+
args: [--ignore-words-list, 'hist,mape,te,nd,fpr,abou,nam', --check-filenames]
5555

5656
- repo: https://github.com/kynan/nbstripout
57-
rev: 0.9.0
57+
rev: 0.9.1
5858
hooks:
5959
- id: nbstripout
6060
args: [--drop-empty-cells, --keep-output]
@@ -81,7 +81,7 @@ repos:
8181
- --fix
8282

8383
- repo: https://github.com/pre-commit/mirrors-eslint
84-
rev: v10.0.0
84+
rev: v10.0.1
8585
hooks:
8686
- id: eslint
8787
types: [file]
@@ -96,7 +96,7 @@ repos:
9696
- '@stylistic/eslint-plugin'
9797

9898
- repo: https://github.com/python-jsonschema/check-jsonschema
99-
rev: 0.36.1
99+
rev: 0.36.2
100100
hooks:
101101
- id: check-jsonschema
102102
files: ^pymatviz/keys\.yml$

assets/scripts/cluster/composition/cluster_compositions_matbench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def annotate_top_points(row: pd.Series) -> dict[str, Any] | None:
230230
"Bulk Modulus (GPa)",
231231
"K<sub>VRH</sub>",
232232
)
233-
plot_combinations: list[PlotConfig] = [ # ty: ignore[invalid-assignment]
233+
plot_combinations: list[PlotConfig] = [
234234
# 1. Steels with PCA (2D) - shows clear linear trends
235235
(*mb_steels, Embed.magpie, Project.pca, 2, {"x": 0.01, "xanchor": "left"}),
236236
# 2. Steels with t-SNE (2D) - shows non-linear clustering

assets/scripts/phonons/phonon_dos.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ class PhonopyDosMissingError(RuntimeError):
2323
"""Raised when phonopy fails to compute a required DOS output."""
2424

2525

26+
class PhonopyTotalDosMissingError(PhonopyDosMissingError):
27+
"""Raised when phonopy total DOS is missing."""
28+
29+
30+
class PhonopyProjectedDosMissingError(PhonopyDosMissingError):
31+
"""Raised when phonopy projected DOS is missing."""
32+
33+
2634
def show_figure(plotly_figure: go.Figure, title: str, *, y_pos: float = 0.97) -> None:
2735
"""Apply consistent layout settings and display the figure."""
2836
plotly_figure.layout.title = dict(text=title, x=0.5, y=y_pos)
@@ -59,7 +67,7 @@ def show_figure(plotly_figure: go.Figure, title: str, *, y_pos: float = 0.97) ->
5967
phonopy_nacl.run_mesh([10, 10, 10])
6068
phonopy_nacl.run_total_dos()
6169
if phonopy_nacl.total_dos is None:
62-
raise PhonopyDosMissingError
70+
raise PhonopyTotalDosMissingError
6371

6472
plt = phonopy_nacl.plot_total_dos()
6573
plt.title("NaCl DOS plotted by phonopy")
@@ -75,9 +83,9 @@ def show_figure(plotly_figure: go.Figure, title: str, *, y_pos: float = 0.97) ->
7583
phonopy_nacl_pdos.run_projected_dos()
7684
phonopy_nacl_pdos.run_total_dos()
7785
if phonopy_nacl_pdos.total_dos is None:
78-
raise PhonopyDosMissingError
86+
raise PhonopyTotalDosMissingError
7987
if phonopy_nacl_pdos.projected_dos is None:
80-
raise PhonopyDosMissingError
88+
raise PhonopyProjectedDosMissingError
8189

8290
struct = get_pmg_structure(phonopy_nacl_pdos.primitive)
8391
total_dos = PhononDos(

assets/scripts/track_pymatviz_citations.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,13 @@ def save_papers(papers: list[ScholarPaper], filename: str) -> None:
232232
yaml.dump(papers, file, default_flow_style=False, allow_unicode=True)
233233

234234

235+
def clean_author_name(author_name: str) -> str:
236+
"""Remove footnote/superscript markers from an author name."""
237+
cleaned_name = re.sub(r"\^[0-9]+", "", author_name)
238+
cleaned_name = re.sub(r"[⁰¹²³⁴⁵⁶⁷⁸⁹]+", "", cleaned_name)
239+
return re.sub(r"\s+", " ", cleaned_name).strip()
240+
241+
235242
def update_readme(
236243
papers: list[ScholarPaper], readme_path: str = f"{ROOT}/readme.md"
237244
) -> None:
@@ -270,7 +277,9 @@ def update_readme(
270277
if not paper.get("authors"):
271278
print(f"Paper {paper['title']} has no authors, skipping")
272279
continue
273-
authors_str = ", ".join(paper["authors"][:3])
280+
authors_str = ", ".join(
281+
clean_author_name(author) for author in paper["authors"][:3]
282+
)
274283
if len(paper["authors"]) > 3:
275284
authors_str += " et al."
276285

examples/ward_metallic_glasses/formula_features.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def calc_reduced_binary_liquidus_temp(
111111
binary_interpolations: dict[str, interp1d],
112112
*,
113113
on_key_err: Literal["raise", "set-none"] = "set-none",
114-
) -> float:
114+
) -> float | None:
115115
"""Calculate the reduced average binary liquidus temperature for a general alloy.
116116
117117
NOTE the unary melting points from the tabulated data are not used here as
@@ -127,11 +127,11 @@ def calc_reduced_binary_liquidus_temp(
127127
on_key_err ("raise" | "set-none"): How to handle missing binary
128128
systems.
129129
If "raise", raises KeyError. If "set-none", returns None.
130-
Defaults to "raise".
130+
Defaults to "set-none".
131131
132132
Returns:
133-
float: The reduced binary liquidus temperature or None if on_key_err="set-none"
134-
and a binary system is missing.
133+
float | None: Reduced binary liquidus temperature or None if
134+
on_key_err="set-none" and a binary system is missing.
135135
"""
136136
if len(composition) < 2:
137137
return 1.0
@@ -149,7 +149,7 @@ def calc_reduced_binary_liquidus_temp(
149149
except KeyError:
150150
if on_key_err == "raise":
151151
raise
152-
return None # type: ignore[return-value]
152+
return None
153153
binary_weight = sum(comp_dict.values())
154154
temp_alloy += temp_binary * binary_weight
155155
temp_alloy_norm += binary_weight
@@ -335,13 +335,15 @@ def one_hot_encode(df_in: pd.DataFrame) -> pd.DataFrame:
335335
binary_interpolations = load_binary_liquidus_data(zip_path)
336336

337337
# Test with a simple binary composition
338-
test_comp = Composition("Pt50P50")
338+
test_comp = "Pt50P50"
339339
features = calc_liu_features(
340340
[test_comp], binary_liquidus_data=binary_interpolations
341341
)
342342
print("\nFeatures for Pt50P50:")
343343
for feature, values in features.items():
344-
print(f"{feature}: {values[test_comp]:.2f}") # type: ignore[index]
344+
feature_value = values.get(test_comp)
345+
value_text = "N/A" if feature_value is None else f"{feature_value:.2f}"
346+
print(f"{feature}: {value_text}")
345347

346348
# Test with a more complex composition
347349
test_comp2 = "Zr6.2Ti45.8Cu39.9Ni5.1Sn3"
@@ -350,7 +352,9 @@ def one_hot_encode(df_in: pd.DataFrame) -> pd.DataFrame:
350352
)
351353
print(f"\nFeatures for {test_comp2}:")
352354
for feature, values in features2.items():
353-
print(f"{feature}: {values[test_comp2]:.2f}")
355+
feature_value = values.get(test_comp2)
356+
value_text = "N/A" if feature_value is None else f"{feature_value:.2f}"
357+
print(f"{feature}: {value_text}")
354358

355359
# Example of batch processing with a DataFrame
356360
df_test = pd.DataFrame(

pymatviz/brillouin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def brillouin_zone_3d(
8585

8686
for idx, (struct_key, structure) in enumerate(structures.items(), start=1):
8787
# Convert pymatgen Structure to seekpath input format
88-
lattice = structure.lattice # ty: ignore[possibly-missing-attribute]
89-
frac_coords = structure.frac_coords # ty: ignore[possibly-missing-attribute]
88+
lattice = structure.lattice
89+
frac_coords = structure.frac_coords
9090
spglib_atoms = (
9191
lattice.matrix, # cell
9292
frac_coords, # positions

pymatviz/classify/confusion_matrix.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Confusion matrix plotting functions."""
22

33
from collections.abc import Callable, Sequence
4-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, cast
55

66
import numpy as np
77
import plotly.graph_objects as go
@@ -133,6 +133,7 @@ def confusion_matrix(
133133
if annotations is None:
134134
processed_annotations = fmt_tile_vals
135135
elif callable(annotations): # If annotations is a callable, apply it to each cell
136+
annotation_func = cast("Callable[[int, int, float, float], str]", annotations)
136137
total = sample_counts.sum()
137138
anno_matrix = []
138139
for ii in range(len(conf_mat_arr)):
@@ -150,7 +151,7 @@ def confusion_matrix(
150151
if conf_mat_arr[:, jj].sum() > 0
151152
else 0
152153
)
153-
row += [annotations(count, total, row_pct, col_pct)]
154+
row += [annotation_func(count, total, row_pct, col_pct)]
154155
anno_matrix += [row]
155156
processed_annotations = np.array(anno_matrix).T
156157
else: # When custom annotations provided, append percentage values

pymatviz/cluster/composition/plot.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
import warnings
77
from collections.abc import Callable, Sequence
8-
from typing import TYPE_CHECKING, Any, Literal, Protocol, get_args
8+
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast, get_args
99

1010
import numpy as np
1111
import pandas as pd
@@ -609,7 +609,8 @@ def cluster_compositions(
609609
# Create embeddings
610610
if callable(embedding_method):
611611
# Use custom embedding function
612-
embeddings = embedding_method(compositions, **(embedding_kwargs or {}))
612+
embedding_fn = cast("EmbeddingCallable", embedding_method)
613+
embeddings = embedding_fn(compositions, **(embedding_kwargs or {}))
613614
# Use built-in embedding methods
614615
elif embedding_method == "one-hot":
615616
embeddings = one_hot_encode(compositions, **(embedding_kwargs or {}))
@@ -628,7 +629,8 @@ def cluster_compositions(
628629
# Project embeddings
629630
if callable(projection):
630631
# Use custom projection function
631-
projected = projection(
632+
projection_fn = cast("ProjectionCallable", projection)
633+
projected = projection_fn(
632634
embeddings,
633635
n_components=n_components,
634636
**projection_kwargs,

pymatviz/phonons/figures.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import sys
66
from collections import defaultdict
7-
from collections.abc import Mapping
7+
from collections.abc import Callable, Mapping
88
from typing import TYPE_CHECKING, Any, Literal, cast
99

1010
import numpy as np
@@ -271,7 +271,10 @@ def phonon_bands(
271271
# Apply line style based on line_kwargs type
272272
if callable(line_kwargs):
273273
# Pass band data and index to callback
274-
custom_style = line_kwargs(frequencies, band_idx)
274+
line_style_callback = cast(
275+
"Callable[[np.ndarray, int], dict[str, Any]]", line_kwargs
276+
)
277+
custom_style = line_style_callback(frequencies, band_idx)
275278
line_defaults |= custom_style
276279
elif isinstance(line_kwargs, dict):
277280
# check for custom line styles for one or both modes
@@ -413,6 +416,9 @@ def phonon_dos(
413416
)
414417

415418
dos_dict: dict[str, PhononDos] = {}
419+
stack_group_by_trace: dict[str, str] | None = (
420+
{} if stack and project is not None else None
421+
)
416422
total_overlay_dict: dict[str, PhononDos] = {}
417423
for label, raw_dos in raw_doses.items():
418424
label_prefix = f"{label} - " if label else ""
@@ -448,7 +454,11 @@ def phonon_dos(
448454
for site_idx, site in enumerate(raw_dos.structure)
449455
}
450456
)
451-
dos_dict |= {f"{label_prefix}{key}": dos for key, dos in projected_dos.items()}
457+
for key, dos in projected_dos.items():
458+
trace_name = f"{label_prefix}{key}"
459+
dos_dict[trace_name] = dos
460+
if stack_group_by_trace is not None:
461+
stack_group_by_trace[trace_name] = label
452462
if show_total:
453463
total_overlay_dict[f"{label_prefix}Total"] = PhononDos(
454464
raw_dos.frequencies, raw_dos.densities
@@ -493,20 +503,27 @@ def _prepare_dos(dos: PhononDos) -> tuple[np.ndarray, np.ndarray]:
493503

494504
fig = go.Figure()
495505
cumulative_density_by_group: dict[str, np.ndarray] = {}
506+
seen_stack_groups: set[str] = set()
507+
508+
def _stack_group(trace_name: str) -> str:
509+
"""Return stack accumulation group for this DOS trace."""
510+
if project is None:
511+
return ""
512+
return stack_group_by_trace.get(trace_name, "") if stack_group_by_trace else ""
513+
496514
for dos_name, dos_obj in dos_dict.items():
497515
frequencies, densities = _prepare_dos(dos_obj)
498516
scatter_kwargs: dict[str, Any] = {"mode": "lines"}
499517
if stack:
500-
stack_group = (
501-
""
502-
if project is None or " - " not in dos_name
503-
else dos_name.split(" - ", maxsplit=1)[0]
504-
)
518+
stack_group = _stack_group(dos_name)
505519
densities = densities + cumulative_density_by_group.get(
506520
stack_group, np.zeros_like(densities)
507521
)
508522
cumulative_density_by_group[stack_group] = densities
509-
scatter_kwargs["fill"] = "tonexty"
523+
scatter_kwargs["fill"] = (
524+
"tozeroy" if stack_group not in seen_stack_groups else "tonexty"
525+
)
526+
seen_stack_groups.add(stack_group)
510527
fig.add_scatter(
511528
x=frequencies, y=densities, name=dos_name, **scatter_kwargs | kwargs
512529
)

pymatviz/powerups.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from collections.abc import Callable, Sequence
6-
from typing import TYPE_CHECKING, Any, Literal, get_args
6+
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
77

88
import numpy as np
99
import plotly.express as px
@@ -219,7 +219,10 @@ def _get_valid_traces(
219219
valid_range = f"0-{len(fig.data) - 1}"
220220
raise ValueError(f"No valid trace indices in {traces}, {valid_range=}")
221221
elif callable(traces):
222-
selected_traces = [idx for idx, trace in enumerate(fig.data) if traces(trace)]
222+
trace_predicate = cast("TracePredicate", traces)
223+
selected_traces = [
224+
idx for idx, trace in enumerate(fig.data) if trace_predicate(trace)
225+
]
223226
if not selected_traces:
224227
raise ValueError("No traces matched the filtering function")
225228
else:
@@ -884,10 +887,18 @@ def validate_ecdf_trace(trace: go.Scatter) -> bool:
884887
fig.data[-1].update(**current_trace_kwargs)
885888

886889
# Make sure yaxis2 has color set if specified in trace_kwargs
887-
if "line" in trace_kwargs and "color" in trace_kwargs["line"]:
888-
fig.layout.yaxis2.color = trace_kwargs["line"]["color"]
889-
elif "line_color" in trace_kwargs:
890-
fig.layout.yaxis2.color = trace_kwargs["line_color"]
890+
yaxis2_color: str | None = None
891+
line_settings = trace_kwargs.get("line")
892+
if isinstance(line_settings, dict):
893+
line_color = line_settings.get("color")
894+
if isinstance(line_color, str):
895+
yaxis2_color = line_color
896+
if yaxis2_color is None:
897+
top_level_line_color = trace_kwargs.get("line_color")
898+
if isinstance(top_level_line_color, str):
899+
yaxis2_color = top_level_line_color
900+
if yaxis2_color is not None:
901+
fig.layout.yaxis2.update(color=yaxis2_color)
891902

892903
return fig
893904

0 commit comments

Comments
 (0)