Skip to content

Commit e5f91d2

Browse files
authored
New example app: visualize structure relaxation (#323)
* add get_breadcrumb() doc str * add crystal_toolkit/apps/examples/relaxation_trajectory.py * add normed force mean over atoms line to relaxation plot clean up code and styles
1 parent e10319d commit e5f91d2

File tree

5 files changed

+164
-20
lines changed

5 files changed

+164
-20
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import sys
2+
3+
import numpy as np
4+
import pandas as pd
5+
import plotly.graph_objects as go
6+
from dash import Dash, dcc, html
7+
from dash.dependencies import Input, Output
8+
from pymatgen.core import Structure
9+
from pymatgen.ext.matproj import MPRester
10+
11+
import crystal_toolkit.components as ctc
12+
from crystal_toolkit.settings import SETTINGS
13+
14+
mp_id = "mp-1033715"
15+
with MPRester(monty_decode=False) as mpr:
16+
[task_doc] = mpr.tasks.search(task_ids=[mp_id])
17+
18+
steps = [
19+
(
20+
Structure.from_dict(step["structure"]),
21+
step["e_fr_energy"],
22+
np.linalg.norm(step["forces"], axis=1).mean(),
23+
)
24+
for calc in reversed(task_doc.calcs_reversed)
25+
for step in calc.output["ionic_steps"]
26+
]
27+
assert len(steps) == 99
28+
29+
e_col = "Energy (eV)"
30+
force_col = "Force (eV/Å)"
31+
spg_col = "Spacegroup"
32+
struct_col = "Structure"
33+
34+
df_traj = pd.DataFrame(steps, columns=[struct_col, e_col, force_col])
35+
df_traj[spg_col] = df_traj[struct_col].map(Structure.get_space_group_info)
36+
37+
38+
def plot_energy_and_forces(
39+
df: pd.DataFrame,
40+
step: int,
41+
e_col: str,
42+
force_col: str,
43+
title: str,
44+
) -> go.Figure:
45+
"""Plot energy and forces as a function of relaxation step."""
46+
fig = go.Figure()
47+
# energy trace = primary y-axis
48+
fig.add_trace(go.Scatter(x=df.index, y=df[e_col], mode="lines", name="Energy"))
49+
50+
# forces trace = secondary y-axis
51+
fig.add_trace(
52+
go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2")
53+
)
54+
55+
fig.update_layout(
56+
template="plotly_white",
57+
title=title,
58+
xaxis={"title": "Relaxation Step"},
59+
yaxis={"title": e_col},
60+
yaxis2={"title": force_col, "overlaying": "y", "side": "right"},
61+
legend=dict(yanchor="top", y=1, xanchor="right", x=1),
62+
)
63+
64+
# vertical line at the specified step
65+
fig.add_vline(x=step, line={"dash": "dash", "width": 1})
66+
67+
return fig
68+
69+
70+
if "struct_comp" not in locals():
71+
struct_comp = ctc.StructureMoleculeComponent(
72+
id="structure", struct_or_mol=df_traj[struct_col][0]
73+
)
74+
75+
step_size = max(1, len(steps) // 20) # ensure slider has max 20 steps
76+
slider = dcc.Slider(
77+
id="slider", min=0, max=len(steps) - 1, value=0, step=step_size, updatemode="drag"
78+
)
79+
80+
81+
def make_title(spg: tuple[str, int]) -> str:
82+
"""Return a title for the figure."""
83+
href = f"https://materialsproject.org/materials/{mp_id}/"
84+
return f"<a {href=}>{mp_id}</a> - {spg[0]} ({spg[1]})"
85+
86+
87+
title = make_title(df_traj[spg_col][0])
88+
graph = dcc.Graph(
89+
id="fig",
90+
figure=plot_energy_and_forces(df_traj, 0, e_col, force_col, title),
91+
style={"maxWidth": "50%"},
92+
)
93+
94+
app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)
95+
app.layout = html.Div(
96+
[
97+
html.H1(
98+
"Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em")
99+
),
100+
html.P("Drag slider to see structure at different relaxation steps."),
101+
slider,
102+
html.Div(
103+
[struct_comp.layout(), graph],
104+
style=dict(display="flex", gap="2em", placeContent="center"),
105+
),
106+
],
107+
style=dict(margin="auto", textAlign="center", maxWidth="1000px", padding="2em"),
108+
)
109+
110+
ctc.register_crystal_toolkit(app=app, layout=app.layout)
111+
112+
113+
@app.callback(
114+
Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value")
115+
)
116+
def update_structure(step: int) -> tuple[Structure, go.Figure]:
117+
"""Update the structure displayed in the StructureMoleculeComponent and the
118+
dashed vertical line in the figure when the slider is moved.
119+
"""
120+
title = make_title(df_traj[spg_col][step])
121+
fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title)
122+
123+
return df_traj[struct_col][step], fig
124+
125+
126+
# https://stackoverflow.com/a/74918941
127+
is_jupyter = "ipykernel" in sys.modules
128+
129+
app.run(port=8050, debug=True, use_reloader=not is_jupyter)

crystal_toolkit/components/structure.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ class StructureMoleculeComponent(MPComponent):
8484

8585
def __init__(
8686
self,
87-
struct_or_mol: None
88-
| (Structure | StructureGraph | Molecule | MoleculeGraph) = None,
87+
struct_or_mol: (
88+
None | Structure | StructureGraph | Molecule | MoleculeGraph
89+
) = None,
8990
id: str = None,
9091
className: str = "box",
9192
scene_additions: Scene | None = None,

crystal_toolkit/helpers/layouts.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -439,20 +439,24 @@ def __init__(self, *args, **kwargs) -> None:
439439

440440

441441
def get_breadcrumb(parts):
442+
"""Create a breadcrumb navigation bar.
443+
444+
Args:
445+
parts (dict): Dictionary of name, link pairs.
446+
447+
Returns:
448+
html.Nav: Breadcrumb navigation bar.
449+
"""
442450
if not parts:
443-
return html.Div()
451+
return html.Nav()
444452

445-
breadcrumbs = html.Nav(
446-
html.Ul(
447-
[
448-
html.Li(
449-
dcc.Link(name, href=link),
450-
className=(None if idx != len(parts) - 1 else "is-active"),
451-
)
452-
for idx, (name, link) in enumerate(parts.items())
453-
]
454-
),
455-
className="breadcrumb",
456-
)
453+
links = [
454+
html.Li(
455+
dcc.Link(name, href=link),
456+
className="is-active" if idx == len(parts) - 1 else None,
457+
)
458+
for idx, (name, link) in enumerate(parts.items())
459+
]
460+
breadcrumbs = html.Nav(html.Ul(links), className="breadcrumb")
457461

458462
return breadcrumbs

crystal_toolkit/helpers/povray_renderer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@
135135

136136

137137
def pov_write_data(input_scene_comp, fstream):
138-
"""parse a primitive display object in crystaltoolkit and print it to POV-Ray input_scene_comp
139-
fstream.
138+
"""Parse a primitive display object in crystaltoolkit and print it to POV-Ray
139+
input_scene_comp fstream.
140140
"""
141141
vect = "{:.4f},{:.4f},{:.4f}"
142142

crystal_toolkit/renderables/volumetric.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
35
import numpy as np
6+
from numpy.typing import ArrayLike
47
from pymatgen.io.vasp import VolumetricData
58

69
from crystal_toolkit.core.scene import Scene, Surface
@@ -9,15 +12,22 @@
912

1013

1114
def get_isosurface_scene(
12-
self, data_key="total", isolvl=0.05, step_size=4, origin=None, **kwargs
13-
):
15+
self,
16+
data_key: str = "total",
17+
isolvl: float = 0.05,
18+
step_size: int = 4,
19+
origin: ArrayLike = None,
20+
**kwargs: Any,
21+
) -> Scene:
1422
"""Get the isosurface from a VolumetricData object.
1523
1624
Args:
1725
data_key (str, optional): Use the volumetric data from self.data[data_key]. Defaults to 'total'.
1826
isolvl (float, optional): The cutoff for the isosurface to using the same units as VESTA so
19-
e/bohr and kept grid size independent
27+
e/bohr and kept grid size independent
2028
step_size (int, optional): step_size parameter for marching_cubes_lewiner. Defaults to 3.
29+
origin (ArrayLike, optional): The origin of the isosurface. Defaults to None.
30+
**kwargs: Passed to the Surface object.
2131
2232
Returns:
2333
Scene: object containing the isosurface component

0 commit comments

Comments
 (0)