|
| 1 | +import pandas as pd |
| 2 | +import plotly.express as px |
| 3 | +import plotly.graph_objects as go |
| 4 | +from dash import Dash, dcc, html |
| 5 | +from dash.dependencies import Input, Output |
| 6 | +from pymatgen.core import Structure |
| 7 | +from pymatgen.ext.matproj import MPRester |
| 8 | + |
| 9 | +import crystal_toolkit.components as ctc |
| 10 | +from crystal_toolkit.settings import SETTINGS |
| 11 | + |
| 12 | +mp_id = "mp-1033715" |
| 13 | +with MPRester(monty_decode=False) as mpr: |
| 14 | + [task_doc] = mpr.tasks.search(task_ids=[mp_id]) |
| 15 | + |
| 16 | + |
| 17 | +steps = [ |
| 18 | + (Structure.from_dict(step["structure"]), step["e_fr_energy"]) |
| 19 | + for calc in reversed(task_doc.calcs_reversed) |
| 20 | + for step in calc.output["ionic_steps"] |
| 21 | +] |
| 22 | +struct_traj, energies = zip(*steps) |
| 23 | +assert len(steps) == 99 |
| 24 | + |
| 25 | +e_col = "energy (eV/atom)" |
| 26 | +spg_col = "spacegroup" |
| 27 | +df_traj = pd.DataFrame( |
| 28 | + {e_col: energies, spg_col: [s.get_space_group_info() for s in struct_traj]} |
| 29 | +) |
| 30 | + |
| 31 | + |
| 32 | +def plot_energy(df: pd.DataFrame, step: int) -> go.Figure: |
| 33 | + """Plot energy as a function of relaxation step.""" |
| 34 | + href = f"https://materialsproject.org/materials/{mp_id}" |
| 35 | + title = f"<a {href=}>{mp_id}</a> - {spg_col} = {df[spg_col][step]}" |
| 36 | + fig = px.line(df, y=e_col, template="plotly_white", title=title) |
| 37 | + fig.add_vline(x=step, line=dict(dash="dash", width=1)) |
| 38 | + return fig |
| 39 | + |
| 40 | + |
| 41 | +struct_comp = ctc.StructureMoleculeComponent( |
| 42 | + id="structure", struct_or_mol=struct_traj[0] |
| 43 | +) |
| 44 | + |
| 45 | +step_size = max(1, len(struct_traj) // 20) # ensure slider has max 20 steps |
| 46 | +slider = dcc.Slider( |
| 47 | + id="slider", |
| 48 | + min=0, |
| 49 | + max=len(struct_traj) - 1, |
| 50 | + value=0, |
| 51 | + step=step_size, |
| 52 | + updatemode="drag", |
| 53 | +) |
| 54 | + |
| 55 | +graph = dcc.Graph(id="fig", figure=plot_energy(df_traj, 0), style={"maxWidth": "50%"}) |
| 56 | + |
| 57 | +app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH) |
| 58 | +app.layout = html.Div( |
| 59 | + [ |
| 60 | + html.H1( |
| 61 | + "Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em") |
| 62 | + ), |
| 63 | + html.P("Drag slider to see structure at different relaxation steps."), |
| 64 | + slider, |
| 65 | + html.Div([struct_comp.layout(), graph], style=dict(display="flex", gap="2em")), |
| 66 | + ], |
| 67 | + style=dict( |
| 68 | + margin="2em auto", placeItems="center", textAlign="center", maxWidth="1000px" |
| 69 | + ), |
| 70 | +) |
| 71 | + |
| 72 | +ctc.register_crystal_toolkit(app=app, layout=app.layout) |
| 73 | + |
| 74 | + |
| 75 | +@app.callback( |
| 76 | + Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value") |
| 77 | +) |
| 78 | +def update_structure(step: int) -> tuple[Structure, go.Figure]: |
| 79 | + """Update the structure displayed in the StructureMoleculeComponent and the |
| 80 | + dashed vertical line in the figure when the slider is moved. |
| 81 | + """ |
| 82 | + return struct_traj[step], plot_energy(df_traj, step) |
| 83 | + |
| 84 | + |
| 85 | +app.run_server(port=8050, debug=True) |
0 commit comments