Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions crystal_toolkit/apps/examples/relaxation_trajectory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import sys

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import Dash, dcc, html
from dash.dependencies import Input, Output
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester

import crystal_toolkit.components as ctc
from crystal_toolkit.settings import SETTINGS

mp_id = "mp-1033715"
with MPRester(monty_decode=False) as mpr:
[task_doc] = mpr.tasks.search(task_ids=[mp_id])

steps = [
(
Structure.from_dict(step["structure"]),
step["e_fr_energy"],
np.linalg.norm(step["forces"], axis=1).mean(),
)
for calc in reversed(task_doc.calcs_reversed)
for step in calc.output["ionic_steps"]
]
assert len(steps) == 99

e_col = "Energy (eV)"
force_col = "Force (eV/Å)"
spg_col = "Spacegroup"
struct_col = "Structure"

df_traj = pd.DataFrame(steps, columns=[struct_col, e_col, force_col])
df_traj[spg_col] = df_traj[struct_col].map(Structure.get_space_group_info)


def plot_energy_and_forces(
df: pd.DataFrame,
step: int,
e_col: str,
force_col: str,
title: str,
) -> go.Figure:
"""Plot energy and forces as a function of relaxation step."""
fig = go.Figure()
# energy trace = primary y-axis
fig.add_trace(go.Scatter(x=df.index, y=df[e_col], mode="lines", name="Energy"))

# forces trace = secondary y-axis
fig.add_trace(
go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2")
)

fig.update_layout(
template="plotly_white",
title=title,
xaxis={"title": "Relaxation Step"},
yaxis={"title": e_col},
yaxis2={"title": force_col, "overlaying": "y", "side": "right"},
legend=dict(yanchor="top", y=1, xanchor="right", x=1),
)

# vertical line at the specified step
fig.add_vline(x=step, line={"dash": "dash", "width": 1})

return fig


if "struct_comp" not in locals():
struct_comp = ctc.StructureMoleculeComponent(
id="structure", struct_or_mol=df_traj[struct_col][0]
)

step_size = max(1, len(steps) // 20) # ensure slider has max 20 steps
slider = dcc.Slider(
id="slider", min=0, max=len(steps) - 1, value=0, step=step_size, updatemode="drag"
)


def make_title(spg: tuple[str, int]) -> str:
"""Return a title for the figure."""
href = f"https://materialsproject.org/materials/{mp_id}/"
return f"<a {href=}>{mp_id}</a> - {spg[0]} ({spg[1]})"


title = make_title(df_traj[spg_col][0])
graph = dcc.Graph(
id="fig",
figure=plot_energy_and_forces(df_traj, 0, e_col, force_col, title),
style={"maxWidth": "50%"},
)

app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)
app.layout = html.Div(
[
html.H1(
"Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em")
),
html.P("Drag slider to see structure at different relaxation steps."),
slider,
html.Div(
[struct_comp.layout(), graph],
style=dict(display="flex", gap="2em", placeContent="center"),
),
],
style=dict(margin="auto", textAlign="center", maxWidth="1000px", padding="2em"),
)

ctc.register_crystal_toolkit(app=app, layout=app.layout)


@app.callback(
Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value")
)
def update_structure(step: int) -> tuple[Structure, go.Figure]:
"""Update the structure displayed in the StructureMoleculeComponent and the
dashed vertical line in the figure when the slider is moved.
"""
title = make_title(df_traj[spg_col][step])
fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title)

return df_traj[struct_col][step], fig


# https://stackoverflow.com/a/74918941
is_jupyter = "ipykernel" in sys.modules

app.run(port=8050, debug=True, use_reloader=not is_jupyter)
5 changes: 3 additions & 2 deletions crystal_toolkit/components/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ class StructureMoleculeComponent(MPComponent):

def __init__(
self,
struct_or_mol: None
| (Structure | StructureGraph | Molecule | MoleculeGraph) = None,
struct_or_mol: (
None | Structure | StructureGraph | Molecule | MoleculeGraph
) = None,
id: str = None,
className: str = "box",
scene_additions: Scene | None = None,
Expand Down
30 changes: 17 additions & 13 deletions crystal_toolkit/helpers/layouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,20 +439,24 @@ def __init__(self, *args, **kwargs) -> None:


def get_breadcrumb(parts):
"""Create a breadcrumb navigation bar.

Args:
parts (dict): Dictionary of name, link pairs.

Returns:
html.Nav: Breadcrumb navigation bar.
"""
if not parts:
return html.Div()
return html.Nav()

breadcrumbs = html.Nav(
html.Ul(
[
html.Li(
dcc.Link(name, href=link),
className=(None if idx != len(parts) - 1 else "is-active"),
)
for idx, (name, link) in enumerate(parts.items())
]
),
className="breadcrumb",
)
links = [
html.Li(
dcc.Link(name, href=link),
className="is-active" if idx == len(parts) - 1 else None,
)
for idx, (name, link) in enumerate(parts.items())
]
breadcrumbs = html.Nav(html.Ul(links), className="breadcrumb")

return breadcrumbs
4 changes: 2 additions & 2 deletions crystal_toolkit/helpers/povray_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@


def pov_write_data(input_scene_comp, fstream):
"""parse a primitive display object in crystaltoolkit and print it to POV-Ray input_scene_comp
fstream.
"""Parse a primitive display object in crystaltoolkit and print it to POV-Ray
input_scene_comp fstream.
"""
vect = "{:.4f},{:.4f},{:.4f}"

Expand Down
16 changes: 13 additions & 3 deletions crystal_toolkit/renderables/volumetric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from typing import Any

import numpy as np
from numpy.typing import ArrayLike
from pymatgen.io.vasp import VolumetricData

from crystal_toolkit.core.scene import Scene, Surface
Expand All @@ -9,15 +12,22 @@


def get_isosurface_scene(
self, data_key="total", isolvl=0.05, step_size=4, origin=None, **kwargs
):
self,
data_key: str = "total",
isolvl: float = 0.05,
step_size: int = 4,
origin: ArrayLike = None,
**kwargs: Any,
) -> Scene:
"""Get the isosurface from a VolumetricData object.

Args:
data_key (str, optional): Use the volumetric data from self.data[data_key]. Defaults to 'total'.
isolvl (float, optional): The cutoff for the isosurface to using the same units as VESTA so
e/bohr and kept grid size independent
e/bohr and kept grid size independent
step_size (int, optional): step_size parameter for marching_cubes_lewiner. Defaults to 3.
origin (ArrayLike, optional): The origin of the isosurface. Defaults to None.
**kwargs: Passed to the Surface object.

Returns:
Scene: object containing the isosurface component
Expand Down