From 56497f21885ce9b8111aa0ad69be841ac6fa878d Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 15:18:31 +0800 Subject: [PATCH 1/8] add atom index to _meta --- crystal_toolkit/core/scene.py | 21 ++++++++++++------- crystal_toolkit/renderables/site.py | 5 +++++ crystal_toolkit/renderables/structuregraph.py | 20 ++++++++++++++++++ 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/crystal_toolkit/core/scene.py b/crystal_toolkit/core/scene.py index 392c4d49..01b6f7f1 100644 --- a/crystal_toolkit/core/scene.py +++ b/crystal_toolkit/core/scene.py @@ -67,13 +67,13 @@ def __add__(self, other): lattice=self.lattice, _meta={self.name: self._meta, other.name: other._meta}, ) - - def _repr_mimebundle_(self, include=None, exclude=None): - """Render Scenes using crystaltoolkit-extension for Jupyter Lab.""" - return { - "application/vnd.mp.ctk+json": self.to_json(), - "text/plain": repr(self), - } + + # def _repr_mimebundle_(self, include=None, exclude=None): + # """Render Scenes using crystaltoolkit-extension for Jupyter Lab.""" + # return { + # "application/vnd.mp.ctk+json": self.to_json(), + # "text/plain": repr(self), + # } def to_json(self): """Convert a Scene into JSON. It will implicitly assume all None values means that attribute @@ -149,7 +149,6 @@ def merge_primitives(primitives): """ mergeable = defaultdict(list) remainder = [] - for primitive in primitives: if isinstance(primitive, Scene): primitive.contents = Scene.merge_primitives(primitive.contents) @@ -214,6 +213,7 @@ def merge(cls, sphere_list): visible=sphere_list[0].visible, clickable=sphere_list[0].clickable, tooltip=sphere_list[0].tooltip, + _meta=sphere_list[0]._meta, ) @@ -320,6 +320,10 @@ def merge(cls, cylinder_list): chain.from_iterable([cylinder.positionPairs for cylinder in cylinder_list]) ) + new_meta_list = list( + chain.from_iterable([[cylinder._meta] for cylinder in cylinder_list]) + ) + return cls( positionPairs=new_positionPairs, color=cylinder_list[0].color, @@ -327,6 +331,7 @@ def merge(cls, cylinder_list): visible=cylinder_list[0].visible, clickable=cylinder_list[0].clickable, tooltip=cylinder_list[0].tooltip, + _meta=new_meta_list ) @property diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index ef43be57..d02aa95e 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -135,6 +135,7 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, + _meta=[site_idx] ) atoms.append(sphere) @@ -207,6 +208,7 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -218,6 +220,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) @@ -228,6 +231,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -251,6 +255,7 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) diff --git a/crystal_toolkit/renderables/structuregraph.py b/crystal_toolkit/renderables/structuregraph.py index 6c8077c2..5c7abb25 100644 --- a/crystal_toolkit/renderables/structuregraph.py +++ b/crystal_toolkit/renderables/structuregraph.py @@ -197,6 +197,8 @@ def get_weight_color(weight): explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull, legend=legend, bond_radius=bond_radius, + site_idx=idx, + show_atom_idx=True, **(site_get_scene_kwargs or {}), ) @@ -217,6 +219,23 @@ def get_weight_color(weight): primitives["unit_cell"].append(self.structure.lattice.get_scene()) + """ + ss = Scene( + name="StructureGraph", + origin=origin, + contents=[ + Scene(name=key, contents=val, origin=origin) + for key, val in primitives.items() + ], + ) + print(id(ss)) + print(ss.contents[1]) + print(ss.contents[1].contents[0]._meta) + print(ss) + + return(ss) + """ + # why primitives comprehension? just make explicit! more readable return Scene( name="StructureGraph", @@ -226,6 +245,7 @@ def get_weight_color(weight): for key, val in primitives.items() ], ) + StructureGraph._get_sites_to_draw = _get_sites_to_draw From a0b9888a976070112f363500cec6a119783b293f Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 16:15:26 +0800 Subject: [PATCH 2/8] Revert "add atom index to _meta" This reverts commit 56497f21885ce9b8111aa0ad69be841ac6fa878d. --- crystal_toolkit/core/scene.py | 21 +++++++------------ crystal_toolkit/renderables/site.py | 5 ----- crystal_toolkit/renderables/structuregraph.py | 20 ------------------ 3 files changed, 8 insertions(+), 38 deletions(-) diff --git a/crystal_toolkit/core/scene.py b/crystal_toolkit/core/scene.py index 01b6f7f1..392c4d49 100644 --- a/crystal_toolkit/core/scene.py +++ b/crystal_toolkit/core/scene.py @@ -67,13 +67,13 @@ def __add__(self, other): lattice=self.lattice, _meta={self.name: self._meta, other.name: other._meta}, ) - - # def _repr_mimebundle_(self, include=None, exclude=None): - # """Render Scenes using crystaltoolkit-extension for Jupyter Lab.""" - # return { - # "application/vnd.mp.ctk+json": self.to_json(), - # "text/plain": repr(self), - # } + + def _repr_mimebundle_(self, include=None, exclude=None): + """Render Scenes using crystaltoolkit-extension for Jupyter Lab.""" + return { + "application/vnd.mp.ctk+json": self.to_json(), + "text/plain": repr(self), + } def to_json(self): """Convert a Scene into JSON. It will implicitly assume all None values means that attribute @@ -149,6 +149,7 @@ def merge_primitives(primitives): """ mergeable = defaultdict(list) remainder = [] + for primitive in primitives: if isinstance(primitive, Scene): primitive.contents = Scene.merge_primitives(primitive.contents) @@ -213,7 +214,6 @@ def merge(cls, sphere_list): visible=sphere_list[0].visible, clickable=sphere_list[0].clickable, tooltip=sphere_list[0].tooltip, - _meta=sphere_list[0]._meta, ) @@ -320,10 +320,6 @@ def merge(cls, cylinder_list): chain.from_iterable([cylinder.positionPairs for cylinder in cylinder_list]) ) - new_meta_list = list( - chain.from_iterable([[cylinder._meta] for cylinder in cylinder_list]) - ) - return cls( positionPairs=new_positionPairs, color=cylinder_list[0].color, @@ -331,7 +327,6 @@ def merge(cls, cylinder_list): visible=cylinder_list[0].visible, clickable=cylinder_list[0].clickable, tooltip=cylinder_list[0].tooltip, - _meta=new_meta_list ) @property diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index d02aa95e..ef43be57 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -135,7 +135,6 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, - _meta=[site_idx] ) atoms.append(sphere) @@ -208,7 +207,6 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -220,7 +218,6 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) @@ -231,7 +228,6 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -255,7 +251,6 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, - _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) diff --git a/crystal_toolkit/renderables/structuregraph.py b/crystal_toolkit/renderables/structuregraph.py index 5c7abb25..6c8077c2 100644 --- a/crystal_toolkit/renderables/structuregraph.py +++ b/crystal_toolkit/renderables/structuregraph.py @@ -197,8 +197,6 @@ def get_weight_color(weight): explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull, legend=legend, bond_radius=bond_radius, - site_idx=idx, - show_atom_idx=True, **(site_get_scene_kwargs or {}), ) @@ -219,23 +217,6 @@ def get_weight_color(weight): primitives["unit_cell"].append(self.structure.lattice.get_scene()) - """ - ss = Scene( - name="StructureGraph", - origin=origin, - contents=[ - Scene(name=key, contents=val, origin=origin) - for key, val in primitives.items() - ], - ) - print(id(ss)) - print(ss.contents[1]) - print(ss.contents[1].contents[0]._meta) - print(ss) - - return(ss) - """ - # why primitives comprehension? just make explicit! more readable return Scene( name="StructureGraph", @@ -245,7 +226,6 @@ def get_weight_color(weight): for key, val in primitives.items() ], ) - StructureGraph._get_sites_to_draw = _get_sites_to_draw From c024d477905138e19621536996a2350fe1d2d341 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 16:41:24 +0800 Subject: [PATCH 3/8] add index to _meta --- crystal_toolkit/core/scene.py | 7 +++++++ crystal_toolkit/renderables/site.py | 5 +++++ crystal_toolkit/renderables/structuregraph.py | 5 ++++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/crystal_toolkit/core/scene.py b/crystal_toolkit/core/scene.py index 392c4d49..54ee3982 100644 --- a/crystal_toolkit/core/scene.py +++ b/crystal_toolkit/core/scene.py @@ -214,6 +214,7 @@ def merge(cls, sphere_list): visible=sphere_list[0].visible, clickable=sphere_list[0].clickable, tooltip=sphere_list[0].tooltip, + _meta=sphere_list[0]._meta, ) @@ -272,6 +273,7 @@ def merge(cls, ellipsoid_list): ] ) ) + return cls( positions=new_positions, @@ -320,6 +322,10 @@ def merge(cls, cylinder_list): chain.from_iterable([cylinder.positionPairs for cylinder in cylinder_list]) ) + new_meta_list = list( + chain.from_iterable([[cylinder._meta] for cylinder in cylinder_list]) + ) + return cls( positionPairs=new_positionPairs, color=cylinder_list[0].color, @@ -327,6 +333,7 @@ def merge(cls, cylinder_list): visible=cylinder_list[0].visible, clickable=cylinder_list[0].clickable, tooltip=cylinder_list[0].tooltip, + _meta=new_meta_list ) @property diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index ef43be57..d02aa95e 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -135,6 +135,7 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, + _meta=[site_idx] ) atoms.append(sphere) @@ -207,6 +208,7 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -218,6 +220,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) @@ -228,6 +231,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -251,6 +255,7 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) diff --git a/crystal_toolkit/renderables/structuregraph.py b/crystal_toolkit/renderables/structuregraph.py index 6c8077c2..b6315bc4 100644 --- a/crystal_toolkit/renderables/structuregraph.py +++ b/crystal_toolkit/renderables/structuregraph.py @@ -197,6 +197,8 @@ def get_weight_color(weight): explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull, legend=legend, bond_radius=bond_radius, + site_idx=idx, + show_atom_idx=True, **(site_get_scene_kwargs or {}), ) @@ -216,7 +218,7 @@ def get_weight_color(weight): primitives["atoms"] = atoms_scenes primitives["unit_cell"].append(self.structure.lattice.get_scene()) - + # why primitives comprehension? just make explicit! more readable return Scene( name="StructureGraph", @@ -226,6 +228,7 @@ def get_weight_color(weight): for key, val in primitives.items() ], ) + StructureGraph._get_sites_to_draw = _get_sites_to_draw From 52f15bfcb453e5631a41fb9e0cefe642c6c5aac1 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 17:00:33 +0800 Subject: [PATCH 4/8] remove empty line --- crystal_toolkit/renderables/structuregraph.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crystal_toolkit/renderables/structuregraph.py b/crystal_toolkit/renderables/structuregraph.py index b6315bc4..163c7230 100644 --- a/crystal_toolkit/renderables/structuregraph.py +++ b/crystal_toolkit/renderables/structuregraph.py @@ -218,7 +218,7 @@ def get_weight_color(weight): primitives["atoms"] = atoms_scenes primitives["unit_cell"].append(self.structure.lattice.get_scene()) - + # why primitives comprehension? just make explicit! more readable return Scene( name="StructureGraph", @@ -228,7 +228,6 @@ def get_weight_color(weight): for key, val in primitives.items() ], ) - StructureGraph._get_sites_to_draw = _get_sites_to_draw From 3d6266fecbb6f29f06d5105894bfe7585df2b726 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Aug 2025 09:26:33 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- crystal_toolkit/core/scene.py | 3 +-- crystal_toolkit/renderables/site.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/crystal_toolkit/core/scene.py b/crystal_toolkit/core/scene.py index 54ee3982..ad2d0621 100644 --- a/crystal_toolkit/core/scene.py +++ b/crystal_toolkit/core/scene.py @@ -273,7 +273,6 @@ def merge(cls, ellipsoid_list): ] ) ) - return cls( positions=new_positions, @@ -333,7 +332,7 @@ def merge(cls, cylinder_list): visible=cylinder_list[0].visible, clickable=cylinder_list[0].clickable, tooltip=cylinder_list[0].tooltip, - _meta=new_meta_list + _meta=new_meta_list, ) @property diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index d02aa95e..18f42128 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -135,7 +135,7 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, - _meta=[site_idx] + _meta=[site_idx], ) atoms.append(sphere) @@ -208,7 +208,7 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] + _meta=[site_idx, connected_site.index], ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -220,7 +220,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] + _meta=[site_idx, connected_site.index], ) bonds.append(cylinder) @@ -231,7 +231,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] + _meta=[site_idx, connected_site.index], ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -255,7 +255,7 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, - _meta=[site_idx, connected_site.index] + _meta=[site_idx, connected_site.index], ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) From d16ac38036e7e60c8df08692d5f9c55d9a3689d3 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 18:09:47 +0800 Subject: [PATCH 6/8] ruff format --- crystal_toolkit/renderables/site.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index 18f42128..5ce92012 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -47,6 +47,7 @@ def get_site_scene( visualize_bond_orders: bool = False, magmom_scale: float = 1.0, legend: Legend | None = None, + retain_atom_idx: bool = False, ) -> Scene: """Get a Scene object for a Site. @@ -70,6 +71,7 @@ def get_site_scene( visualize_bond_orders (bool, optional): Defaults to False. magmom_scale (float, optional): Defaults to 1.0. legend (Legend | None, optional): Defaults to None. + retain_atom_idx (bool, optional): Defaults to False. Returns: Scene: The scene object containing atoms, bonds, polyhedra, magmoms. From 0e1c9249e81c9d51b094866c320e110f5ba11e99 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 18:19:08 +0800 Subject: [PATCH 7/8] add retain_atom_idx --- crystal_toolkit/renderables/site.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index 5ce92012..bab81bfd 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -137,7 +137,7 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, - _meta=[site_idx], + _meta=[site_idx] if retain_atom_idx else None, ) atoms.append(sphere) @@ -210,7 +210,9 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index], + _meta=[site_idx, connected_site.index] + if retain_atom_idx + else None, ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -222,7 +224,9 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index], + _meta=[site_idx, connected_site.index] + if retain_atom_idx + else None, ) bonds.append(cylinder) @@ -233,7 +237,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index], + _meta=[site_idx, connected_site.index] if retain_atom_idx else None, ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -257,7 +261,7 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, - _meta=[site_idx, connected_site.index], + _meta=[site_idx, connected_site.index] if retain_atom_idx else None, ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) From a16c9c5d2947295425618f1419485cc0271bd0cb Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Tue, 9 Sep 2025 12:46:12 -0700 Subject: [PATCH 8/8] add animation component --- crystal_toolkit/components/phonon.py | 384 +++++++++++++++++---------- 1 file changed, 244 insertions(+), 140 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 296a2947..f0808a0a 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -8,7 +8,11 @@ from dash import dcc, html from dash.dependencies import Component, Input, Output from dash.exceptions import PreventUpdate -from dash_mp_components import CrystalToolkitScene +from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene + +# crystal animation algo +from pymatgen.analysis.graphs import StructureGraph +from pymatgen.analysis.local_env import CrystalNN from pymatgen.ext.matproj import MPRester from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos @@ -17,14 +21,7 @@ from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres -from crystal_toolkit.helpers.layouts import ( - Column, - Columns, - Label, - MessageBody, - MessageContainer, - get_data_list, -) +from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list from crystal_toolkit.helpers.pretty_labels import pretty_labels if TYPE_CHECKING: @@ -64,26 +61,32 @@ def __init__( **kwargs, ) + bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( + self.initial_data["default"] + ) + self.create_store("bs-store", bs) + self.create_store("bs", None) + self.create_store("dos", None) + @property def _sub_layouts(self) -> dict[str, Component]: # defaults state = {"label-select": "sc", "dos-select": "ap"} - bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( - self.initial_data["default"] - ) - fig = PhononBandstructureAndDosComponent.get_figure(bs, dos) + fig = PhononBandstructureAndDosComponent.get_figure(None, None) # Main plot graph = dcc.Graph( figure=fig, config={"displayModeBar": False}, - responsive=True, + responsive=False, id=self.id("ph-bsdos-graph"), ) # Brillouin zone - zone_scene = self.get_brillouin_zone_scene(bs) - zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px") + zone_scene = self.get_brillouin_zone_scene(None) + zone = CrystalToolkitScene( + data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone") + ) # Hide by default if not loaded by mpid, switching between k-paths # on-the-fly only supported for bandstructures retrieved from MP @@ -138,9 +141,29 @@ def _sub_layouts(self) -> dict[str, Component]: style={"width": "200px"}, ) - summary_dict = self._get_data_list_dict(bs, dos) + summary_dict = self._get_data_list_dict(None, None) summary_table = get_data_list(summary_dict) + # crystal visualization + + tip = html.P( + "Click different q-points and bands in the dispersion diagram to see the crystal vibration.", + id=self.id("crystal-tip"), + style={ + "margin": "0 0 12px", + "fontSize": "16px", + "color": "#555", + "textAlign": "center", + }, + ) + + crystal_animation = CrystalToolkitAnimationScene( + data={}, + sceneSize="200px", + id=self.id("crystal-animation"), + settings={"defaultZoom": 1.5}, + ) + return { "graph": graph, "convention": convention, @@ -148,10 +171,15 @@ def _sub_layouts(self) -> dict[str, Component]: "label-select": label_select, "zone": zone, "table": summary_table, + "crystal-animation": crystal_animation, + "tip": tip, } def layout(self) -> html.Div: sub_layouts = self._sub_layouts + crystal_animation = Columns( + [Column([sub_layouts["tip"], sub_layouts["crystal-animation"]])] + ) graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( [ @@ -166,11 +194,147 @@ def layout(self) -> html.Div: ) brillouin_zone = Columns( [ - Column([Label("Summary"), sub_layouts["table"]]), + Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")), Column([Label("Brillouin Zone"), sub_layouts["zone"]]), ] ) - return html.Div([graph, controls, brillouin_zone]) + + return html.Div([graph, crystal_animation, controls, brillouin_zone]) + + @staticmethod + def _get_eigendisplacement( + ph_bs: BandStructureSymmLine, + json_data: dict, + band: int = 0, + qpoint: int = 0, + precision: int = 15, + magnitude: int = 15, + ) -> dict: + if not ph_bs: + return {} + + # get displacement + min_bond_length = float("inf") + for content_idx in range(len(json_data["contents"][1]["contents"])): + for pair_idx in range( + len(json_data["contents"][1]["contents"][content_idx]["_meta"]) + ): + u, v = json_data["contents"][1]["contents"][content_idx][ + "positionPairs" + ][pair_idx] + # Convert to numpy arrays + u = np.array(u) + v = np.array(v) + length = np.linalg.norm(v - u) + min_bond_length = min(min_bond_length, length) + + # atom animate + assert json_data["contents"][0]["name"] == "atoms" + for content_idx in range(len(json_data["contents"][0]["contents"])): + atom_idx = json_data["contents"][0]["contents"][content_idx]["_meta"][0] + + raw_displacement = ph_bs.eigendisplacements[band][qpoint][atom_idx] + + displacement = [complex(vec).real * magnitude for vec in raw_displacement] + + position_animation = [] + for displace_coef in [0, 1, 0, -1, 0]: + displace = [ + round(displace_coef * magnitude * d, precision) + for d in displacement + ] + position_animation.append(displace) + + json_data["contents"][0]["contents"][content_idx]["animate"] = ( + position_animation + ) + json_data["contents"][0]["contents"][content_idx]["keyframes"] = [ + 0, + 1, + 2, + 3, + 4, + ] + json_data["contents"][0]["contents"][content_idx]["animateType"] = ( + "displacement" + ) + + # bond animate + assert json_data["contents"][1]["name"] == "bonds" + for content_idx in range(len(json_data["contents"][1]["contents"])): + bond_animation = [] + + assert len( + json_data["contents"][1]["contents"][content_idx]["_meta"] + ) == len(json_data["contents"][1]["contents"][content_idx]["positionPairs"]) + + for pair_idx in range( + len(json_data["contents"][1]["contents"][content_idx]["_meta"]) + ): + u_idx, v_idx = json_data["contents"][1]["contents"][content_idx][ + "_meta" + ][pair_idx] + + # u + u_raw_displacement = ph_bs.eigendisplacements[band][qpoint][u_idx] + u_displacement = [ + round(complex(vec).real * magnitude, precision) + for vec in u_raw_displacement + ] + + # v + v_raw_displacement = ph_bs.eigendisplacements[band][qpoint][v_idx] + v_displacement = [ + round(complex(vec).real * magnitude, precision) + for vec in v_raw_displacement + ] + + # only draw in unit cell + u_to_middle_bond_animation = [] # u to middle + # v_to_middle_bond_animation = [] # v to middle + for displace_coef in [0, 1, 0, -1, 0]: + u_end_displacement = [ + round(displace_coef * magnitude * d, precision) + for d in u_displacement + ] + v_end_displacement = [ + round(displace_coef * magnitude * d, precision) + for d in v_displacement + ] + middle_end_displacement = ( + (np.array(u_end_displacement) + np.array(v_end_displacement)) + / 2 + ).tolist() + middle_end_displacement = [ + round(dis, precision) for dis in middle_end_displacement + ] + + u2middle_animation = [u_end_displacement, middle_end_displacement] + # v2middle_animation = [v_end_displacement, middle_end_displacement] + + u_to_middle_bond_animation.append(u2middle_animation) + # v_to_middle_bond_animation.append(v2middle_animation) + + bond_animation.append(u_to_middle_bond_animation) + json_data["contents"][1]["contents"][content_idx]["animate"] = ( + bond_animation + ) + json_data["contents"][1]["contents"][content_idx]["keyframes"] = [ + 0, + 1, + 2, + 3, + 4, + ] + json_data["contents"][1]["contents"][content_idx]["animateType"] = ( + "displacement" + ) + + # remove polyhedra manually + json_data["contents"][2]["visible"] = False + json_data["contents"][3]["visible"] = False + + return json_data @staticmethod def _get_ph_bs_dos( @@ -303,6 +467,7 @@ def get_ph_bandstructure_traces(bs, freq_range): "line": {"color": "#1f77b4"}, "hoverinfo": "skip", "name": "Total", + "customdata": [[di, band_num] for di in range(len(x_dat))], "hovertemplate": "%{y:.2f} THz", "showlegend": False, "xaxis": "x", @@ -348,6 +513,9 @@ def get_ph_bandstructure_traces(bs, freq_range): def _get_data_list_dict( bs: PhononBandStructureSymmLine, dos: CompletePhononDos ) -> dict[str, str | bool | int]: + if (not bs) and (not dos): + return {} + bs_minpoint, bs_min_freq = bs.min_freq() min_freq_report = ( f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}" @@ -443,14 +611,9 @@ def get_figure( ph_dos: CompletePhononDos | None = None, freq_range: tuple[float | None, float | None] = (None, None), ) -> go.Figure: - if freq_range[0] is None: - freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) - - if freq_range[1] is None: - freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) - if (not ph_dos) and (not ph_bs): empty_plot_style = { + "height": 500, "xaxis": {"visible": False}, "yaxis": {"visible": False}, "paper_bgcolor": "rgba(0,0,0,0)", @@ -459,6 +622,12 @@ def get_figure( return go.Figure(layout=empty_plot_style) + if freq_range[0] is None: + freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) + + if freq_range[1] is None: + freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) + if ph_bs: ( bs_traces, @@ -555,7 +724,7 @@ def get_figure( paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(230,230,230,230)", margin=dict(l=60, b=50, t=50, pad=0, r=30), - # clickmode="event+select" + clickmode="event+select", ) figure = {"data": bs_traces + dos_traces, "layout": layout} @@ -580,124 +749,25 @@ def get_figure( def generate_callbacks(self, app, cache) -> None: @app.callback( Output(self.id("ph-bsdos-graph"), "figure"), - Input(self.id("traces"), "data"), + Output(self.id("zone"), "data"), + Output(self.id("table"), "children"), + Input(self.id("ph_bs"), "data"), + Input(self.id("ph_dos"), "data"), ) - def update_graph(traces): - if traces == "error": - msg_body = MessageBody( - dcc.Markdown( - "Band structure and density of states not available for this selection." - ) - ) - return (MessageContainer([msg_body], kind="warning"),) - - if traces is None: - raise PreventUpdate - - bs, dos = self._get_ph_bs_dos(self.initial_data["default"]) + def update_graph(bs, dos): + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + if isinstance(dos, dict): + dos = CompletePhononDos.from_dict(dos) figure = self.get_figure(bs, dos) - return dcc.Graph( - figure=figure, config={"displayModeBar": False}, responsive=True - ) - - @app.callback( - Output(self.id("label-select"), "value"), - Output(self.id("label-container"), "style"), - Input(self.id("mpid"), "data"), - Input(self.id("path-convention"), "value"), - ) - def update_label_select(mpid, path_convention): - if not mpid: - raise PreventUpdate - label_value = path_convention - label_style = {"maxWidth": "200"} - - return label_value, label_style - - @app.callback( - Output(self.id("dos-select"), "options"), - Output(self.id("path-convention"), "options"), - Output(self.id("path-container"), "style"), - Input(self.id("elements"), "data"), - Input(self.id("mpid"), "data"), - ) - def update_select(elements, mpid): - if elements is None: - raise PreventUpdate - if not mpid: - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - - path_options = [{"label": "N/A", "value": "sc"}] - path_style = {"maxWidth": "200", "display": "none"} - - return dos_options, path_options, path_style - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - - path_options = [ - {"label": "Setyawan-Curtarolo", "value": "sc"}, - {"label": "Latimer-Munro", "value": "lm"}, - {"label": "Hinuma et al.", "value": "hin"}, - ] - path_style = {"maxWidth": "200"} + zone_scene = self.get_brillouin_zone_scene(bs) - return dos_options, path_options, path_style + summary_dict = self._get_data_list_dict(bs, dos) + summary_table = get_data_list(summary_dict) - @app.callback( - Output(self.id("traces"), "data"), - Output(self.id("elements"), "data"), - Input(self.id(), "data"), - Input(self.id("path-convention"), "value"), - Input(self.id("dos-select"), "value"), - Input(self.id("label-select"), "value"), - ) - def bs_dos_data(data, dos_select, label_select): - # Obtain bands to plot over and generate traces for bs data: - energy_window = (-6.0, 10.0) - - traces = [] - - bsml, density_of_states = self._get_ph_bs_dos(data) - - if self.bandstructure_symm_line: - bs_traces = self.get_ph_bandstructure_traces( - bsml, freq_range=energy_window - ) - traces.append(bs_traces) - - if self.density_of_states: - dos_traces = self.get_ph_dos_traces( - density_of_states, freq_range=energy_window - ) - traces.append(dos_traces) - - # traces = [bs_traces, dos_traces, bs_data] - - # TODO: not tested if this is correct way to get element list - elements = list(map(str, density_of_states.get_element_dos())) - - return traces, elements + return figure, zone_scene.to_json(), summary_table @app.callback( Output(self.id("brillouin-zone"), "data"), @@ -711,8 +781,42 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): # TODO: figure out what to return (CSS?) to highlight BZ edge/point return - # TODO: figure out what to return (CSS?) to highlight BZ edge/point - return + @app.callback( + Output(self.id("crystal-animation"), "data"), + Input(self.id("ph-bsdos-graph"), "clickData"), + Input(self.id("ph_bs"), "data"), + # prevent_initial_call=True + ) + def update_crystal_animation(cd, bs): + if not bs: + raise PreventUpdate + + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + + struc_graph = StructureGraph.from_local_env_strategy( + bs.structure, CrystalNN() + ) + scene = struc_graph.get_scene( + draw_image_atoms=False, + bonded_sites_outside_unit_cell=False, + site_get_scene_kwargs={"retain_atom_idx": True}, + ) + json_data = scene.to_json() + + qpoint = 0 + band_num = 0 + + if cd and cd.get("points"): + pt = cd["points"][0] + qpoint, band_num = pt.get("customdata", [0, 0]) + + return PhononBandstructureAndDosComponent._get_eigendisplacement( + ph_bs=bs, + json_data=json_data, + band=band_num, + qpoint=qpoint, + ) class PhononBandstructureAndDosPanelComponent(PanelComponent):