11from __future__ import annotations
22
33import itertools
4- from copy import deepcopy
54from typing import TYPE_CHECKING , Any
65
76import numpy as np
87import plotly .graph_objects as go
98from dash import dcc , html
109from dash .dependencies import Component , Input , Output , State
1110from dash .exceptions import PreventUpdate
12- from dash_mp_components import CrystalToolkitAnimationScene , CrystalToolkitScene
13-
14- # crystal animation algo
15- from pymatgen .analysis .graphs import StructureGraph
16- from pymatgen .analysis .local_env import CrystalNN
11+ from dash_mp_components import CrystalToolkitScene
1712from pymatgen .ext .matproj import MPRester
1813from pymatgen .phonon .bandstructure import PhononBandStructureSymmLine
1914from pymatgen .phonon .dos import CompletePhononDos
2318from crystal_toolkit .core .mpcomponent import MPComponent
2419from crystal_toolkit .core .panelcomponent import PanelComponent
2520from crystal_toolkit .core .scene import Convex , Cylinders , Lines , Scene , Spheres
26- from crystal_toolkit .helpers .layouts import Column , Columns , Label , get_data_list
21+ from crystal_toolkit .helpers .layouts import (
22+ Column ,
23+ Columns ,
24+ Label ,
25+ MessageBody ,
26+ MessageContainer ,
27+ get_data_list ,
28+ )
2729from crystal_toolkit .helpers .pretty_labels import pretty_labels
2830
2931if TYPE_CHECKING :
@@ -66,32 +68,26 @@ def __init__(
6668 ** kwargs ,
6769 )
6870
69- bs , _ = PhononBandstructureAndDosComponent ._get_ph_bs_dos (
70- self .initial_data ["default" ]
71- )
72- self .create_store ("bs-store" , bs )
73- self .create_store ("bs" , None )
74- self .create_store ("dos" , None )
75-
7671 @property
7772 def _sub_layouts (self ) -> dict [str , Component ]:
7873 # defaults
7974 state = {"label-select" : "sc" , "dos-select" : "ap" }
8075
81- fig = PhononBandstructureAndDosComponent .get_figure (None , None )
76+ bs , dos = PhononBandstructureAndDosComponent ._get_ph_bs_dos (
77+ self .initial_data ["default" ]
78+ )
79+ fig = PhononBandstructureAndDosComponent .get_figure (bs , dos )
8280 # Main plot
8381 graph = dcc .Graph (
8482 figure = fig ,
8583 config = {"displayModeBar" : False },
86- responsive = False ,
84+ responsive = True ,
8785 id = self .id ("ph-bsdos-graph" ),
8886 )
8987
9088 # Brillouin zone
91- zone_scene = self .get_brillouin_zone_scene (None )
92- zone = CrystalToolkitScene (
93- data = zone_scene .to_json (), sceneSize = "500px" , id = self .id ("zone" )
94- )
89+ zone_scene = self .get_brillouin_zone_scene (bs )
90+ zone = CrystalToolkitScene (data = zone_scene .to_json (), sceneSize = "500px" )
9591
9692 # Hide by default if not loaded by mpid, switching between k-paths
9793 # on-the-fly only supported for bandstructures retrieved from MP
@@ -113,11 +109,9 @@ def _sub_layouts(self) -> dict[str, Component]:
113109 options = options ,
114110 )
115111 ],
116- style = (
117- {"width" : "200px" }
118- if show_path_options
119- else {"maxWidth" : "200" , "display" : "none" }
120- ),
112+ style = {"width" : "200px" }
113+ if show_path_options
114+ else {"maxWidth" : "200" , "display" : "none" },
121115 id = self .id ("path-container" ),
122116 )
123117
@@ -132,11 +126,9 @@ def _sub_layouts(self) -> dict[str, Component]:
132126 options = options ,
133127 )
134128 ],
135- style = (
136- {"width" : "200px" }
137- if show_path_options
138- else {"width" : "200px" , "display" : "none" }
139- ),
129+ style = {"width" : "200px" }
130+ if show_path_options
131+ else {"width" : "200px" , "display" : "none" },
140132 id = self .id ("label-container" ),
141133 )
142134
@@ -150,7 +142,7 @@ def _sub_layouts(self) -> dict[str, Component]:
150142 style = {"width" : "200px" },
151143 )
152144
153- summary_dict = self ._get_data_list_dict (None , None )
145+ summary_dict = self ._get_data_list_dict (bs , dos )
154146 summary_table = get_data_list (summary_dict )
155147
156148 # crystal visualization
@@ -272,7 +264,7 @@ def layout(self) -> html.Div:
272264 )
273265 brillouin_zone = Columns (
274266 [
275- Column ([Label ("Summary" ), sub_layouts ["table" ]], id = self . id ( "table" ) ),
267+ Column ([Label ("Summary" ), sub_layouts ["table" ]]),
276268 Column ([Label ("Brillouin Zone" ), sub_layouts ["zone" ]]),
277269 ]
278270 )
@@ -541,7 +533,6 @@ def get_ph_bandstructure_traces(bs, freq_range):
541533 "line" : {"color" : "#1f77b4" },
542534 "hoverinfo" : "skip" ,
543535 "name" : "Total" ,
544- "customdata" : [[di , band_num ] for di in range (len (x_dat ))],
545536 "hovertemplate" : "%{y:.2f} THz" ,
546537 "showlegend" : False ,
547538 "xaxis" : "x" ,
@@ -587,9 +578,6 @@ def get_ph_bandstructure_traces(bs, freq_range):
587578 def _get_data_list_dict (
588579 bs : PhononBandStructureSymmLine , dos : CompletePhononDos
589580 ) -> dict [str , str | bool | int ]:
590- if (not bs ) and (not dos ):
591- return {}
592-
593581 bs_minpoint , bs_min_freq = bs .min_freq ()
594582 min_freq_report = (
595583 f"{ bs_min_freq :.2f} THz at frac. coords. { bs_minpoint .frac_coords } "
@@ -615,7 +603,7 @@ def _get_data_list_dict(
615603 target = "blank" ,
616604 ),
617605 ]
618- ): ( "Yes" if bs .has_nac else "No" ) ,
606+ ): "Yes" if bs .has_nac else "No" ,
619607 "Has imaginary frequencies" : "Yes" if bs .has_imaginary_freq () else "No" ,
620608 "Has eigen-displacements" : "Yes" if bs .has_eigendisplacements else "No" ,
621609 "Min frequency" : min_freq_report ,
@@ -685,9 +673,14 @@ def get_figure(
685673 ph_dos : CompletePhononDos | None = None ,
686674 freq_range : tuple [float | None , float | None ] = (None , None ),
687675 ) -> go .Figure :
676+ if freq_range [0 ] is None :
677+ freq_range = (np .min (ph_bs .bands ) * 1.05 , freq_range [1 ])
678+
679+ if freq_range [1 ] is None :
680+ freq_range = (freq_range [0 ], np .max (ph_bs .bands ) * 1.05 )
681+
688682 if (not ph_dos ) and (not ph_bs ):
689683 empty_plot_style = {
690- "height" : 500 ,
691684 "xaxis" : {"visible" : False },
692685 "yaxis" : {"visible" : False },
693686 "paper_bgcolor" : "rgba(0,0,0,0)" ,
@@ -696,12 +689,6 @@ def get_figure(
696689
697690 return go .Figure (layout = empty_plot_style )
698691
699- if freq_range [0 ] is None :
700- freq_range = (np .min (ph_bs .bands ) * 1.05 , freq_range [1 ])
701-
702- if freq_range [1 ] is None :
703- freq_range = (freq_range [0 ], np .max (ph_bs .bands ) * 1.05 )
704-
705692 if ph_bs :
706693 (
707694 bs_traces ,
@@ -798,7 +785,7 @@ def get_figure(
798785 paper_bgcolor = "rgba(0,0,0,0)" ,
799786 plot_bgcolor = "rgba(230,230,230,230)" ,
800787 margin = dict (l = 60 , b = 50 , t = 50 , pad = 0 , r = 30 ),
801- clickmode = "event+select" ,
788+ # clickmode="event+select"
802789 )
803790
804791 figure = {"data" : bs_traces + dos_traces , "layout" : layout }
@@ -836,6 +823,9 @@ def update_graph(bs, dos, nclick):
836823 dos = CompletePhononDos .from_dict (dos )
837824
838825 figure = self .get_figure (bs , dos )
826+ return dcc .Graph (
827+ figure = figure , config = {"displayModeBar" : False }, responsive = True
828+ )
839829
840830 # remove marker if there is one
841831 figure ["data" ] = [
@@ -870,10 +860,91 @@ def update_graph(bs, dos, nclick):
870860
871861 zone_scene = self .get_brillouin_zone_scene (bs )
872862
873- summary_dict = self ._get_data_list_dict (bs , dos )
874- summary_table = get_data_list (summary_dict )
863+ return label_value , label_style
864+
865+ @app .callback (
866+ Output (self .id ("dos-select" ), "options" ),
867+ Output (self .id ("path-convention" ), "options" ),
868+ Output (self .id ("path-container" ), "style" ),
869+ Input (self .id ("elements" ), "data" ),
870+ Input (self .id ("mpid" ), "data" ),
871+ )
872+ def update_select (elements , mpid ):
873+ if elements is None :
874+ raise PreventUpdate
875+ if not mpid :
876+ dos_options = (
877+ [{"label" : "Element Projected" , "value" : "ap" }]
878+ + [{"label" : "Orbital Projected - Total" , "value" : "op" }]
879+ + [
880+ {
881+ "label" : "Orbital Projected - " + str (ele_label ),
882+ "value" : "orb" + str (ele_label ),
883+ }
884+ for ele_label in elements
885+ ]
886+ )
887+
888+ path_options = [{"label" : "N/A" , "value" : "sc" }]
889+ path_style = {"maxWidth" : "200" , "display" : "none" }
890+
891+ return dos_options , path_options , path_style
892+ dos_options = (
893+ [{"label" : "Element Projected" , "value" : "ap" }]
894+ + [{"label" : "Orbital Projected - Total" , "value" : "op" }]
895+ + [
896+ {
897+ "label" : "Orbital Projected - " + str (ele_label ),
898+ "value" : "orb" + str (ele_label ),
899+ }
900+ for ele_label in elements
901+ ]
902+ )
903+
904+ path_options = [
905+ {"label" : "Setyawan-Curtarolo" , "value" : "sc" },
906+ {"label" : "Latimer-Munro" , "value" : "lm" },
907+ {"label" : "Hinuma et al." , "value" : "hin" },
908+ ]
909+
910+ path_style = {"maxWidth" : "200" }
911+
912+ return dos_options , path_options , path_style
913+
914+ @app .callback (
915+ Output (self .id ("traces" ), "data" ),
916+ Output (self .id ("elements" ), "data" ),
917+ Input (self .id (), "data" ),
918+ Input (self .id ("path-convention" ), "value" ),
919+ Input (self .id ("dos-select" ), "value" ),
920+ Input (self .id ("label-select" ), "value" ),
921+ )
922+ def bs_dos_data (data , dos_select , label_select ):
923+ # Obtain bands to plot over and generate traces for bs data:
924+ energy_window = (- 6.0 , 10.0 )
925+
926+ traces = []
927+
928+ bsml , density_of_states = self ._get_ph_bs_dos (data )
929+
930+ if self .bandstructure_symm_line :
931+ bs_traces = self .get_ph_bandstructure_traces (
932+ bsml , freq_range = energy_window
933+ )
934+ traces .append (bs_traces )
935+
936+ if self .density_of_states :
937+ dos_traces = self .get_ph_dos_traces (
938+ density_of_states , freq_range = energy_window
939+ )
940+ traces .append (dos_traces )
941+
942+ # traces = [bs_traces, dos_traces, bs_data]
943+
944+ # TODO: not tested if this is correct way to get element list
945+ elements = list (map (str , density_of_states .get_element_dos ()))
875946
876- return figure , zone_scene . to_json (), summary_table
947+ return traces , elements
877948
878949 @app .callback (
879950 Output (self .id ("brillouin-zone" ), "data" ),
0 commit comments