diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 296a2947..f0808a0a 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -8,7 +8,11 @@ from dash import dcc, html from dash.dependencies import Component, Input, Output from dash.exceptions import PreventUpdate -from dash_mp_components import CrystalToolkitScene +from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene + +# crystal animation algo +from pymatgen.analysis.graphs import StructureGraph +from pymatgen.analysis.local_env import CrystalNN from pymatgen.ext.matproj import MPRester from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos @@ -17,14 +21,7 @@ from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres -from crystal_toolkit.helpers.layouts import ( - Column, - Columns, - Label, - MessageBody, - MessageContainer, - get_data_list, -) +from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list from crystal_toolkit.helpers.pretty_labels import pretty_labels if TYPE_CHECKING: @@ -64,26 +61,32 @@ def __init__( **kwargs, ) + bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( + self.initial_data["default"] + ) + self.create_store("bs-store", bs) + self.create_store("bs", None) + self.create_store("dos", None) + @property def _sub_layouts(self) -> dict[str, Component]: # defaults state = {"label-select": "sc", "dos-select": "ap"} - bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( - self.initial_data["default"] - ) - fig = PhononBandstructureAndDosComponent.get_figure(bs, dos) + fig = PhononBandstructureAndDosComponent.get_figure(None, None) # Main plot graph = dcc.Graph( figure=fig, config={"displayModeBar": False}, - responsive=True, + responsive=False, id=self.id("ph-bsdos-graph"), ) # Brillouin zone - zone_scene = self.get_brillouin_zone_scene(bs) - zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px") + zone_scene = self.get_brillouin_zone_scene(None) + zone = CrystalToolkitScene( + data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone") + ) # Hide by default if not loaded by mpid, switching between k-paths # on-the-fly only supported for bandstructures retrieved from MP @@ -138,9 +141,29 @@ def _sub_layouts(self) -> dict[str, Component]: style={"width": "200px"}, ) - summary_dict = self._get_data_list_dict(bs, dos) + summary_dict = self._get_data_list_dict(None, None) summary_table = get_data_list(summary_dict) + # crystal visualization + + tip = html.P( + "Click different q-points and bands in the dispersion diagram to see the crystal vibration.", + id=self.id("crystal-tip"), + style={ + "margin": "0 0 12px", + "fontSize": "16px", + "color": "#555", + "textAlign": "center", + }, + ) + + crystal_animation = CrystalToolkitAnimationScene( + data={}, + sceneSize="200px", + id=self.id("crystal-animation"), + settings={"defaultZoom": 1.5}, + ) + return { "graph": graph, "convention": convention, @@ -148,10 +171,15 @@ def _sub_layouts(self) -> dict[str, Component]: "label-select": label_select, "zone": zone, "table": summary_table, + "crystal-animation": crystal_animation, + "tip": tip, } def layout(self) -> html.Div: sub_layouts = self._sub_layouts + crystal_animation = Columns( + [Column([sub_layouts["tip"], sub_layouts["crystal-animation"]])] + ) graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( [ @@ -166,11 +194,147 @@ def layout(self) -> html.Div: ) brillouin_zone = Columns( [ - Column([Label("Summary"), sub_layouts["table"]]), + Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")), Column([Label("Brillouin Zone"), sub_layouts["zone"]]), ] ) - return html.Div([graph, controls, brillouin_zone]) + + return html.Div([graph, crystal_animation, controls, brillouin_zone]) + + @staticmethod + def _get_eigendisplacement( + ph_bs: BandStructureSymmLine, + json_data: dict, + band: int = 0, + qpoint: int = 0, + precision: int = 15, + magnitude: int = 15, + ) -> dict: + if not ph_bs: + return {} + + # get displacement + min_bond_length = float("inf") + for content_idx in range(len(json_data["contents"][1]["contents"])): + for pair_idx in range( + len(json_data["contents"][1]["contents"][content_idx]["_meta"]) + ): + u, v = json_data["contents"][1]["contents"][content_idx][ + "positionPairs" + ][pair_idx] + # Convert to numpy arrays + u = np.array(u) + v = np.array(v) + length = np.linalg.norm(v - u) + min_bond_length = min(min_bond_length, length) + + # atom animate + assert json_data["contents"][0]["name"] == "atoms" + for content_idx in range(len(json_data["contents"][0]["contents"])): + atom_idx = json_data["contents"][0]["contents"][content_idx]["_meta"][0] + + raw_displacement = ph_bs.eigendisplacements[band][qpoint][atom_idx] + + displacement = [complex(vec).real * magnitude for vec in raw_displacement] + + position_animation = [] + for displace_coef in [0, 1, 0, -1, 0]: + displace = [ + round(displace_coef * magnitude * d, precision) + for d in displacement + ] + position_animation.append(displace) + + json_data["contents"][0]["contents"][content_idx]["animate"] = ( + position_animation + ) + json_data["contents"][0]["contents"][content_idx]["keyframes"] = [ + 0, + 1, + 2, + 3, + 4, + ] + json_data["contents"][0]["contents"][content_idx]["animateType"] = ( + "displacement" + ) + + # bond animate + assert json_data["contents"][1]["name"] == "bonds" + for content_idx in range(len(json_data["contents"][1]["contents"])): + bond_animation = [] + + assert len( + json_data["contents"][1]["contents"][content_idx]["_meta"] + ) == len(json_data["contents"][1]["contents"][content_idx]["positionPairs"]) + + for pair_idx in range( + len(json_data["contents"][1]["contents"][content_idx]["_meta"]) + ): + u_idx, v_idx = json_data["contents"][1]["contents"][content_idx][ + "_meta" + ][pair_idx] + + # u + u_raw_displacement = ph_bs.eigendisplacements[band][qpoint][u_idx] + u_displacement = [ + round(complex(vec).real * magnitude, precision) + for vec in u_raw_displacement + ] + + # v + v_raw_displacement = ph_bs.eigendisplacements[band][qpoint][v_idx] + v_displacement = [ + round(complex(vec).real * magnitude, precision) + for vec in v_raw_displacement + ] + + # only draw in unit cell + u_to_middle_bond_animation = [] # u to middle + # v_to_middle_bond_animation = [] # v to middle + for displace_coef in [0, 1, 0, -1, 0]: + u_end_displacement = [ + round(displace_coef * magnitude * d, precision) + for d in u_displacement + ] + v_end_displacement = [ + round(displace_coef * magnitude * d, precision) + for d in v_displacement + ] + middle_end_displacement = ( + (np.array(u_end_displacement) + np.array(v_end_displacement)) + / 2 + ).tolist() + middle_end_displacement = [ + round(dis, precision) for dis in middle_end_displacement + ] + + u2middle_animation = [u_end_displacement, middle_end_displacement] + # v2middle_animation = [v_end_displacement, middle_end_displacement] + + u_to_middle_bond_animation.append(u2middle_animation) + # v_to_middle_bond_animation.append(v2middle_animation) + + bond_animation.append(u_to_middle_bond_animation) + json_data["contents"][1]["contents"][content_idx]["animate"] = ( + bond_animation + ) + json_data["contents"][1]["contents"][content_idx]["keyframes"] = [ + 0, + 1, + 2, + 3, + 4, + ] + json_data["contents"][1]["contents"][content_idx]["animateType"] = ( + "displacement" + ) + + # remove polyhedra manually + json_data["contents"][2]["visible"] = False + json_data["contents"][3]["visible"] = False + + return json_data @staticmethod def _get_ph_bs_dos( @@ -303,6 +467,7 @@ def get_ph_bandstructure_traces(bs, freq_range): "line": {"color": "#1f77b4"}, "hoverinfo": "skip", "name": "Total", + "customdata": [[di, band_num] for di in range(len(x_dat))], "hovertemplate": "%{y:.2f} THz", "showlegend": False, "xaxis": "x", @@ -348,6 +513,9 @@ def get_ph_bandstructure_traces(bs, freq_range): def _get_data_list_dict( bs: PhononBandStructureSymmLine, dos: CompletePhononDos ) -> dict[str, str | bool | int]: + if (not bs) and (not dos): + return {} + bs_minpoint, bs_min_freq = bs.min_freq() min_freq_report = ( f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}" @@ -443,14 +611,9 @@ def get_figure( ph_dos: CompletePhononDos | None = None, freq_range: tuple[float | None, float | None] = (None, None), ) -> go.Figure: - if freq_range[0] is None: - freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) - - if freq_range[1] is None: - freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) - if (not ph_dos) and (not ph_bs): empty_plot_style = { + "height": 500, "xaxis": {"visible": False}, "yaxis": {"visible": False}, "paper_bgcolor": "rgba(0,0,0,0)", @@ -459,6 +622,12 @@ def get_figure( return go.Figure(layout=empty_plot_style) + if freq_range[0] is None: + freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) + + if freq_range[1] is None: + freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) + if ph_bs: ( bs_traces, @@ -555,7 +724,7 @@ def get_figure( paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(230,230,230,230)", margin=dict(l=60, b=50, t=50, pad=0, r=30), - # clickmode="event+select" + clickmode="event+select", ) figure = {"data": bs_traces + dos_traces, "layout": layout} @@ -580,124 +749,25 @@ def get_figure( def generate_callbacks(self, app, cache) -> None: @app.callback( Output(self.id("ph-bsdos-graph"), "figure"), - Input(self.id("traces"), "data"), + Output(self.id("zone"), "data"), + Output(self.id("table"), "children"), + Input(self.id("ph_bs"), "data"), + Input(self.id("ph_dos"), "data"), ) - def update_graph(traces): - if traces == "error": - msg_body = MessageBody( - dcc.Markdown( - "Band structure and density of states not available for this selection." - ) - ) - return (MessageContainer([msg_body], kind="warning"),) - - if traces is None: - raise PreventUpdate - - bs, dos = self._get_ph_bs_dos(self.initial_data["default"]) + def update_graph(bs, dos): + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + if isinstance(dos, dict): + dos = CompletePhononDos.from_dict(dos) figure = self.get_figure(bs, dos) - return dcc.Graph( - figure=figure, config={"displayModeBar": False}, responsive=True - ) - - @app.callback( - Output(self.id("label-select"), "value"), - Output(self.id("label-container"), "style"), - Input(self.id("mpid"), "data"), - Input(self.id("path-convention"), "value"), - ) - def update_label_select(mpid, path_convention): - if not mpid: - raise PreventUpdate - label_value = path_convention - label_style = {"maxWidth": "200"} - - return label_value, label_style - - @app.callback( - Output(self.id("dos-select"), "options"), - Output(self.id("path-convention"), "options"), - Output(self.id("path-container"), "style"), - Input(self.id("elements"), "data"), - Input(self.id("mpid"), "data"), - ) - def update_select(elements, mpid): - if elements is None: - raise PreventUpdate - if not mpid: - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - - path_options = [{"label": "N/A", "value": "sc"}] - path_style = {"maxWidth": "200", "display": "none"} - - return dos_options, path_options, path_style - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - - path_options = [ - {"label": "Setyawan-Curtarolo", "value": "sc"}, - {"label": "Latimer-Munro", "value": "lm"}, - {"label": "Hinuma et al.", "value": "hin"}, - ] - path_style = {"maxWidth": "200"} + zone_scene = self.get_brillouin_zone_scene(bs) - return dos_options, path_options, path_style + summary_dict = self._get_data_list_dict(bs, dos) + summary_table = get_data_list(summary_dict) - @app.callback( - Output(self.id("traces"), "data"), - Output(self.id("elements"), "data"), - Input(self.id(), "data"), - Input(self.id("path-convention"), "value"), - Input(self.id("dos-select"), "value"), - Input(self.id("label-select"), "value"), - ) - def bs_dos_data(data, dos_select, label_select): - # Obtain bands to plot over and generate traces for bs data: - energy_window = (-6.0, 10.0) - - traces = [] - - bsml, density_of_states = self._get_ph_bs_dos(data) - - if self.bandstructure_symm_line: - bs_traces = self.get_ph_bandstructure_traces( - bsml, freq_range=energy_window - ) - traces.append(bs_traces) - - if self.density_of_states: - dos_traces = self.get_ph_dos_traces( - density_of_states, freq_range=energy_window - ) - traces.append(dos_traces) - - # traces = [bs_traces, dos_traces, bs_data] - - # TODO: not tested if this is correct way to get element list - elements = list(map(str, density_of_states.get_element_dos())) - - return traces, elements + return figure, zone_scene.to_json(), summary_table @app.callback( Output(self.id("brillouin-zone"), "data"), @@ -711,8 +781,42 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): # TODO: figure out what to return (CSS?) to highlight BZ edge/point return - # TODO: figure out what to return (CSS?) to highlight BZ edge/point - return + @app.callback( + Output(self.id("crystal-animation"), "data"), + Input(self.id("ph-bsdos-graph"), "clickData"), + Input(self.id("ph_bs"), "data"), + # prevent_initial_call=True + ) + def update_crystal_animation(cd, bs): + if not bs: + raise PreventUpdate + + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + + struc_graph = StructureGraph.from_local_env_strategy( + bs.structure, CrystalNN() + ) + scene = struc_graph.get_scene( + draw_image_atoms=False, + bonded_sites_outside_unit_cell=False, + site_get_scene_kwargs={"retain_atom_idx": True}, + ) + json_data = scene.to_json() + + qpoint = 0 + band_num = 0 + + if cd and cd.get("points"): + pt = cd["points"][0] + qpoint, band_num = pt.get("customdata", [0, 0]) + + return PhononBandstructureAndDosComponent._get_eigendisplacement( + ph_bs=bs, + json_data=json_data, + band=band_num, + qpoint=qpoint, + ) class PhononBandstructureAndDosPanelComponent(PanelComponent):