Skip to content

Commit 0ed2233

Browse files
committed
add crystal_toolkit/apps/examples/relaxation_trajectory.py
1 parent 2819be9 commit 0ed2233

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pandas as pd
2+
import plotly.express as px
3+
import plotly.graph_objects as go
4+
from dash import Dash, dcc, html
5+
from dash.dependencies import Input, Output
6+
from pymatgen.core import Structure
7+
from pymatgen.ext.matproj import MPRester
8+
9+
import crystal_toolkit.components as ctc
10+
from crystal_toolkit.settings import SETTINGS
11+
12+
mp_id = "mp-1033715"
13+
with MPRester(monty_decode=False) as mpr:
14+
[task_doc] = mpr.tasks.search(task_ids=[mp_id])
15+
16+
17+
steps = [
18+
(Structure.from_dict(step["structure"]), step["e_fr_energy"])
19+
for calc in reversed(task_doc.calcs_reversed)
20+
for step in calc.output["ionic_steps"]
21+
]
22+
struct_traj, energies = zip(*steps)
23+
assert len(steps) == 99
24+
25+
e_col = "energy (eV/atom)"
26+
spg_col = "spacegroup"
27+
df_traj = pd.DataFrame(
28+
{e_col: energies, spg_col: [s.get_space_group_info() for s in struct_traj]}
29+
)
30+
31+
32+
def plot_energy(df: pd.DataFrame, step: int) -> go.Figure:
33+
"""Plot energy as a function of relaxation step."""
34+
href = f"https://materialsproject.org/materials/{mp_id}"
35+
title = f"<a {href=}>{mp_id}</a> - {spg_col} = {df[spg_col][step]}"
36+
fig = px.line(df, y=e_col, template="plotly_white", title=title)
37+
fig.add_vline(x=step, line=dict(dash="dash", width=1))
38+
return fig
39+
40+
41+
struct_comp = ctc.StructureMoleculeComponent(
42+
id="structure", struct_or_mol=struct_traj[0]
43+
)
44+
45+
step_size = max(1, len(struct_traj) // 20) # ensure slider has max 20 steps
46+
slider = dcc.Slider(
47+
id="slider",
48+
min=0,
49+
max=len(struct_traj) - 1,
50+
value=0,
51+
step=step_size,
52+
updatemode="drag",
53+
)
54+
55+
graph = dcc.Graph(id="fig", figure=plot_energy(df_traj, 0), style={"maxWidth": "50%"})
56+
57+
app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)
58+
app.layout = html.Div(
59+
[
60+
html.H1(
61+
"Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em")
62+
),
63+
html.P("Drag slider to see structure at different relaxation steps."),
64+
slider,
65+
html.Div([struct_comp.layout(), graph], style=dict(display="flex", gap="2em")),
66+
],
67+
style=dict(
68+
margin="2em auto", placeItems="center", textAlign="center", maxWidth="1000px"
69+
),
70+
)
71+
72+
ctc.register_crystal_toolkit(app=app, layout=app.layout)
73+
74+
75+
@app.callback(
76+
Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value")
77+
)
78+
def update_structure(step: int) -> tuple[Structure, go.Figure]:
79+
"""Update the structure displayed in the StructureMoleculeComponent and the
80+
dashed vertical line in the figure when the slider is moved.
81+
"""
82+
return struct_traj[step], plot_energy(df_traj, step)
83+
84+
85+
app.run_server(port=8050, debug=True)

crystal_toolkit/components/structure.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ class StructureMoleculeComponent(MPComponent):
8484

8585
def __init__(
8686
self,
87-
struct_or_mol: None
88-
| (Structure | StructureGraph | Molecule | MoleculeGraph) = None,
87+
struct_or_mol: (
88+
None | Structure | StructureGraph | Molecule | MoleculeGraph
89+
) = None,
8990
id: str = None,
9091
className: str = "box",
9192
scene_additions: Scene | None = None,

0 commit comments

Comments
 (0)