77import numpy as np
88import plotly .graph_objects as go
99from dash import dcc , html
10- from dash .dependencies import Component , Input , Output
10+ from dash .dependencies import Component , Input , Output , State
1111from dash .exceptions import PreventUpdate
1212from dash_mp_components import CrystalToolkitAnimationScene , CrystalToolkitScene
1313
3434MARKER_COLOR = "red"
3535MARKER_SIZE = 12
3636MARKER_SHAPE = "x"
37- MAX_MAGNITUDE = 400
37+ MAX_MAGNITUDE = 300
3838MIN_MAGNITUDE = 0
3939
4040# TODOs:
@@ -165,6 +165,8 @@ def _sub_layouts(self) -> dict[str, Component]:
165165 sceneSize = "500px" ,
166166 id = self .id ("crystal-animation" ),
167167 settings = {"defaultZoom" : 1.2 },
168+ axisView = "SW" ,
169+ showControls = False , # disable download for now
168170 ),
169171 style = {"width" : "60%" },
170172 )
@@ -184,25 +186,28 @@ def _sub_layouts(self) -> dict[str, Component]:
184186 default = 1 ,
185187 is_int = True ,
186188 label = "x" ,
189+ min = 1 ,
187190 style = {"width" : "5rem" },
188191 ),
189192 self .get_numerical_input (
190193 kwarg_label = "scale-y" ,
191194 default = 1 ,
192195 is_int = True ,
193196 label = "y" ,
197+ min = 1 ,
194198 style = {"width" : "5rem" },
195199 ),
196200 self .get_numerical_input (
197201 kwarg_label = "scale-z" ,
198202 default = 1 ,
199203 is_int = True ,
200204 label = "z" ,
205+ min = 1 ,
201206 style = {"width" : "5rem" },
202207 ),
203208 html .Button (
204209 "Update" ,
205- id = self .id ("controls-btn" ),
210+ id = self .id ("supercell- controls-btn" ),
206211 style = {"height" : "40px" },
207212 ),
208213 ],
@@ -233,13 +238,12 @@ def _sub_layouts(self) -> dict[str, Component]:
233238 "crystal-animation-controls" : crystal_animation_controls ,
234239 }
235240
236- def layout (self ) -> html . Div :
241+ def _get_animation_panel (self ):
237242 sub_layouts = self ._sub_layouts
238- crystal_animation = Columns (
243+ return Columns (
239244 [
240245 Column (
241246 [
242- # sub_layouts["tip"],
243247 Columns (
244248 [
245249 sub_layouts ["crystal-animation" ],
@@ -248,9 +252,12 @@ def layout(self) -> html.Div:
248252 )
249253 ]
250254 ),
251- # Column([sub_layouts["crystal-animation-controls"]])
252255 ]
253256 )
257+
258+ def layout (self ) -> html .Div :
259+ sub_layouts = self ._sub_layouts
260+ crystal_animation = self ._get_animation_panel ()
254261 graph = Columns ([Column ([sub_layouts ["graph" ]])])
255262 controls = Columns (
256263 [
@@ -279,8 +286,8 @@ def _get_eigendisplacement(
279286 band : int = 0 ,
280287 qpoint : int = 0 ,
281288 precision : int = 15 ,
282- magnitude : int = 225 ,
283- total_repeat_cell_cnt : int | None = None ,
289+ magnitude : int = MAX_MAGNITUDE / 2 ,
290+ total_repeat_cell_cnt : int = 1 ,
284291 ) -> dict :
285292 if not ph_bs or not json_data :
286293 return {}
@@ -305,8 +312,9 @@ def calc_max_displacement(idx: int) -> list:
305312
306313 # get the atom index
307314 assert total_repeat_cell_cnt != 0
315+
308316 modified_idx = (
309- (idx // total_repeat_cell_cnt ) if total_repeat_cell_cnt else idx
317+ int (idx // total_repeat_cell_cnt ) if total_repeat_cell_cnt else idx
310318 )
311319
312320 return [
@@ -793,27 +801,7 @@ def get_figure(
793801 clickmode = "event+select" ,
794802 )
795803
796- default_red_dot = [
797- {
798- "type" : "scatter" ,
799- "mode" : "markers" ,
800- "x" : [0 ],
801- "y" : [0 ],
802- "marker" : {
803- "color" : MARKER_COLOR ,
804- "size" : MARKER_SIZE ,
805- "symbol" : MARKER_SHAPE ,
806- },
807- "name" : "click-marker" ,
808- "showlegend" : False ,
809- "customdata" : [[0 , 0 ]],
810- "hovertemplate" : (
811- "band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
812- ),
813- }
814- ]
815-
816- figure = {"data" : bs_traces + dos_traces + default_red_dot , "layout" : layout }
804+ figure = {"data" : bs_traces + dos_traces , "layout" : layout }
817805
818806 legend = dict (
819807 x = 1.02 ,
@@ -848,37 +836,37 @@ def update_graph(bs, dos, nclick):
848836 dos = CompletePhononDos .from_dict (dos )
849837
850838 figure = self .get_figure (bs , dos )
851- if nclick and nclick .get ("points" ):
852- # remove marker if there is one
853- figure ["data" ] = [
854- t for t in figure ["data" ] if t .get ("name" ) != "click-marker"
855- ]
856839
857- x_click = nclick ["points" ][0 ]["x" ]
858- y_click = nclick ["points" ][0 ]["y" ]
840+ # remove marker if there is one
841+ figure ["data" ] = [
842+ t for t in figure ["data" ] if t .get ("name" ) != "click-marker"
843+ ]
859844
860- pt = nclick ["points" ][0 ]
861- qpoint , band_num = pt .get ("customdata" , [0 , 0 ])
845+ x_click = nclick ["points" ][0 ]["x" ] if nclick else 0
846+ y_click = nclick ["points" ][0 ]["y" ] if nclick else 0
847+ pt = nclick ["points" ][0 ] if nclick else {}
862848
863- figure ["data" ].append (
864- {
865- "type" : "scatter" ,
866- "mode" : "markers" ,
867- "x" : [x_click ],
868- "y" : [y_click ],
869- "marker" : {
870- "color" : MARKER_COLOR ,
871- "size" : MARKER_SIZE ,
872- "symbol" : MARKER_SHAPE ,
873- },
874- "name" : "click-marker" ,
875- "showlegend" : False ,
876- "customdata" : [[qpoint , band_num ]],
877- "hovertemplate" : (
878- "band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
879- ),
880- }
881- )
849+ qpoint , band_num = pt .get ("customdata" , [0 , 0 ])
850+
851+ figure ["data" ].append (
852+ {
853+ "type" : "scatter" ,
854+ "mode" : "markers" ,
855+ "x" : [x_click ],
856+ "y" : [y_click ],
857+ "marker" : {
858+ "color" : MARKER_COLOR ,
859+ "size" : MARKER_SIZE ,
860+ "symbol" : MARKER_SHAPE ,
861+ },
862+ "name" : "click-marker" ,
863+ "showlegend" : False ,
864+ "customdata" : [[qpoint , band_num ]],
865+ "hovertemplate" : (
866+ "band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
867+ ),
868+ }
869+ )
882870
883871 zone_scene = self .get_brillouin_zone_scene (bs )
884872
@@ -903,35 +891,45 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
903891 Output (self .id ("crystal-animation" ), "data" ),
904892 Input (self .id ("ph-bsdos-graph" ), "clickData" ),
905893 Input (self .id ("ph_bs" ), "data" ),
906- Input (self .id ("controls-btn" ), "n_clicks" ),
907- Input (self .get_all_kwargs_id (), "value" ),
894+ Input (self .id ("supercell-controls-btn" ), "n_clicks" ),
895+ Input (self .get_kwarg_id ("magnitude" ), "value" ),
896+ State (self .get_kwarg_id ("scale-x" ), "value" ),
897+ State (self .get_kwarg_id ("scale-y" ), "value" ),
898+ State (self .get_kwarg_id ("scale-z" ), "value" ),
908899 # prevent_initial_call=True
909900 )
910- def update_crystal_animation (cd , bs , update , kwargs ):
901+ def update_crystal_animation (
902+ cd , bs , sueprcell_update , magnitude_fraction , scale_x , scale_y , scale_z
903+ ):
904+ # Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields,
905+ # ensuring updates occur only after the `supercell-controls-btn`` is clicked.
906+
911907 if not bs :
912908 raise PreventUpdate
913909
910+ # Since `self.get_kwarg_id()` uses dash.dependencies.ALL, it returns a list of values.
911+ # Although we could use `magnitude_fraction = magnitude_fraction[0]` to get the first value,
912+ # this approach provides better clarity and readability.
913+ kwargs = self .reconstruct_kwargs_from_state ()
914+ magnitude_fraction = kwargs .get ("magnitude" )
915+ scale_x = kwargs .get ("scale-x" )
916+ scale_y = kwargs .get ("scale-y" )
917+ scale_z = kwargs .get ("scale-z" )
918+
914919 if isinstance (bs , dict ):
915920 bs = PhononBandStructureSymmLine .from_dict (bs )
916921
917- kwargs = self .reconstruct_kwargs_from_state ()
922+ struct = bs .structure
923+ total_repeat_cell_cnt = 1
924+ # update structure if the controls got triggered
925+ if sueprcell_update :
926+ total_repeat_cell_cnt = scale_x * scale_y * scale_z
918927
919- # animation control
920- scale_x , scale_y , scale_z = (
921- int (kwargs ["scale-x" ]),
922- int (kwargs ["scale-y" ]),
923- int (kwargs ["scale-z" ]),
924- )
925- magnitude_fraction = kwargs ["magnitude" ]
926- magnitude = (
927- MAX_MAGNITUDE - MIN_MAGNITUDE
928- ) * magnitude_fraction + MIN_MAGNITUDE
929-
930- # create supercell
931- trans = SupercellTransformation (
932- ((scale_x , 0 , 0 ), (0 , scale_y , 0 ), (0 , 0 , scale_z ))
933- )
934- struct = trans .apply_transformation (bs .structure )
928+ # create supercell
929+ trans = SupercellTransformation (
930+ ((scale_x , 0 , 0 ), (0 , scale_y , 0 ), (0 , 0 , scale_z ))
931+ )
932+ struct = trans .apply_transformation (struct )
935933
936934 struc_graph = StructureGraph .from_local_env_strategy (struct , CrystalNN ())
937935 scene = struc_graph .get_scene (
@@ -948,12 +946,17 @@ def update_crystal_animation(cd, bs, update, kwargs):
948946 pt = cd ["points" ][0 ]
949947 qpoint , band_num = pt .get ("customdata" , [0 , 0 ])
950948
949+ # magnitude
950+ magnitude = (
951+ MAX_MAGNITUDE - MIN_MAGNITUDE
952+ ) * magnitude_fraction + MIN_MAGNITUDE
953+
951954 return PhononBandstructureAndDosComponent ._get_eigendisplacement (
952955 ph_bs = bs ,
953956 json_data = json_data ,
954957 band = band_num ,
955958 qpoint = qpoint ,
956- total_repeat_cell_cnt = scale_x * scale_y * scale_z ,
959+ total_repeat_cell_cnt = total_repeat_cell_cnt ,
957960 magnitude = magnitude ,
958961 )
959962
0 commit comments