Skip to content

Commit f5fe3da

Browse files
committed
add crystal_toolkit/apps/examples/relaxation_trajectory.py
1 parent 2819be9 commit f5fe3da

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import pandas as pd
2+
import plotly.express as px
3+
import plotly.graph_objects as go
4+
from dash import Dash, dcc, html
5+
from dash.dependencies import Input, Output
6+
from pymatgen.core import Structure
7+
from pymatgen.ext.matproj import MPRester
8+
9+
import crystal_toolkit.components as ctc
10+
from crystal_toolkit.settings import SETTINGS
11+
12+
mp_id = "mp-1033715"
13+
with MPRester(monty_decode=False) as mpr:
14+
[task_doc] = mpr.tasks.search(task_ids=[mp_id])
15+
16+
17+
struct_traj = [
18+
Structure.from_dict(step["structure"])
19+
for calc in reversed(task_doc.calcs_reversed)
20+
for step in calc.output["ionic_steps"]
21+
]
22+
assert len(struct_traj) == 99
23+
energies = [
24+
step["e_fr_energy"]
25+
for calc in task_doc.calcs_reversed
26+
for step in calc.output["ionic_steps"]
27+
]
28+
assert len(energies) == len(struct_traj)
29+
30+
e_col = "energy (eV/atom)"
31+
spg_col = "spacegroup"
32+
df_traj = pd.DataFrame(
33+
{e_col: energies, spg_col: [s.get_space_group_info() for s in struct_traj]}
34+
)
35+
36+
struct_comp = ctc.StructureMoleculeComponent(
37+
id="structure", struct_or_mol=struct_traj[0]
38+
)
39+
40+
41+
step_size = max(1, len(struct_traj) // 20) # ensure slider has max 20 steps
42+
slider = dcc.Slider(
43+
id="slider",
44+
min=0,
45+
max=len(struct_traj) - 1,
46+
value=0,
47+
step=step_size,
48+
updatemode="drag",
49+
)
50+
51+
52+
def plot_energy(df: pd.DataFrame, step: int) -> go.Figure:
53+
"""Plot energy as a function of relaxation step."""
54+
title = f"{mp_id} - {spg_col} = {df[spg_col][step]}"
55+
fig = px.line(df, y=e_col, template="plotly_white", title=title)
56+
fig.add_vline(x=step, line=dict(dash="dash", width=1))
57+
return fig
58+
59+
60+
graph = dcc.Graph(id="fig", figure=plot_energy(df_traj, 0), style={"maxWidth": "50%"})
61+
62+
app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)
63+
app.layout = html.Div(
64+
[
65+
html.H1(
66+
"Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em")
67+
),
68+
html.P("Drag slider to see structure at different relaxation steps."),
69+
slider,
70+
html.Div([struct_comp.layout(), graph], style=dict(display="flex", gap="2em")),
71+
],
72+
style=dict(
73+
margin="2em auto", placeItems="center", textAlign="center", maxWidth="1000px"
74+
),
75+
)
76+
77+
ctc.register_crystal_toolkit(app=app, layout=app.layout)
78+
79+
80+
@app.callback(
81+
Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value")
82+
)
83+
def update_structure(step: int) -> tuple[Structure, go.Figure]:
84+
"""Update the structure displayed in the StructureMoleculeComponent and the
85+
dashed vertical line in the figure when the slider is moved.
86+
"""
87+
return struct_traj[step], plot_energy(df_traj, step)
88+
89+
90+
app.run_server(port=8050, debug=True)

0 commit comments

Comments
 (0)