1818from pymatgen .phonon .bandstructure import PhononBandStructureSymmLine
1919from pymatgen .phonon .dos import CompletePhononDos
2020from pymatgen .phonon .plotter import PhononBSPlotter
21+ from pymatgen .transformations .standard_transformations import SupercellTransformation
2122
2223from crystal_toolkit .core .mpcomponent import MPComponent
2324from crystal_toolkit .core .panelcomponent import PanelComponent
3031 from pymatgen .electronic_structure .dos import CompleteDos
3132
3233DISPLACE_COEF = [0 , 1 , 0 , - 1 , 0 ]
34+ MARKER_COLOR = "red"
35+ MARKER_SIZE = 12
36+ MARKER_SHAPE = "x"
37+ MAX_MAGNITUDE = 400
38+ MIN_MAGNITUDE = 0
3339
3440# TODOs:
3541# - look for additional projection methods in phonon DOS (currently only atom
@@ -149,22 +155,70 @@ def _sub_layouts(self) -> dict[str, Component]:
149155
150156 # crystal visualization
151157
152- tip = html .P (
153- "Click different q-points and bands in the dispersion diagram to see the crystal vibration." ,
154- id = self .id ("crystal-tip" ),
155- style = {
156- "margin" : "0 0 12px" ,
157- "fontSize" : "16px" ,
158- "color" : "#555" ,
159- "textAlign" : "center" ,
160- },
158+ tip = html .H5 (
159+ "💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!" ,
160+ )
161+
162+ crystal_animation = html .Div (
163+ CrystalToolkitAnimationScene (
164+ data = {},
165+ sceneSize = "500px" ,
166+ id = self .id ("crystal-animation" ),
167+ settings = {"defaultZoom" : 1.2 },
168+ ),
169+ style = {"width" : "60%" },
161170 )
162171
163- crystal_animation = CrystalToolkitAnimationScene (
164- data = {},
165- sceneSize = "200px" ,
166- id = self .id ("crystal-animation" ),
167- settings = {"defaultZoom" : 1.5 },
172+ crystal_animation_controls = html .Div (
173+ [
174+ html .Br (),
175+ html .Div (tip , style = {"textAlign" : "center" }),
176+ html .Br (),
177+ html .H5 ("Control Panel" , style = {"textAlign" : "center" }),
178+ html .H6 ("Supercell modification" ),
179+ html .Br (),
180+ html .Div (
181+ [
182+ self .get_numerical_input (
183+ kwarg_label = "scale-x" ,
184+ default = 1 ,
185+ is_int = True ,
186+ label = "x" ,
187+ style = {"width" : "5rem" },
188+ ),
189+ self .get_numerical_input (
190+ kwarg_label = "scale-y" ,
191+ default = 1 ,
192+ is_int = True ,
193+ label = "y" ,
194+ style = {"width" : "5rem" },
195+ ),
196+ self .get_numerical_input (
197+ kwarg_label = "scale-z" ,
198+ default = 1 ,
199+ is_int = True ,
200+ label = "z" ,
201+ style = {"width" : "5rem" },
202+ ),
203+ html .Button (
204+ "Update" ,
205+ id = self .id ("controls-btn" ),
206+ style = {"height" : "40px" },
207+ ),
208+ ],
209+ style = {"display" : "flex" },
210+ ),
211+ html .Br (),
212+ html .Div (
213+ self .get_slider_input (
214+ kwarg_label = "magnitude" ,
215+ default = 0.5 ,
216+ step = 0.01 ,
217+ domain = [0 , 1 ],
218+ label = "Vibration magnitude" ,
219+ )
220+ ),
221+ ],
168222 )
169223
170224 return {
@@ -176,12 +230,26 @@ def _sub_layouts(self) -> dict[str, Component]:
176230 "table" : summary_table ,
177231 "crystal-animation" : crystal_animation ,
178232 "tip" : tip ,
233+ "crystal-animation-controls" : crystal_animation_controls ,
179234 }
180235
181236 def layout (self ) -> html .Div :
182237 sub_layouts = self ._sub_layouts
183238 crystal_animation = Columns (
184- [Column ([sub_layouts ["tip" ], sub_layouts ["crystal-animation" ]])]
239+ [
240+ Column (
241+ [
242+ # sub_layouts["tip"],
243+ Columns (
244+ [
245+ sub_layouts ["crystal-animation" ],
246+ sub_layouts ["crystal-animation-controls" ],
247+ ]
248+ )
249+ ]
250+ ),
251+ # Column([sub_layouts["crystal-animation-controls"]])
252+ ]
185253 )
186254 graph = Columns ([Column ([sub_layouts ["graph" ]])])
187255 controls = Columns (
@@ -212,6 +280,7 @@ def _get_eigendisplacement(
212280 qpoint : int = 0 ,
213281 precision : int = 15 ,
214282 magnitude : int = 225 ,
283+ total_repeat_cell_cnt : int | None = None ,
215284 ) -> dict :
216285 if not ph_bs or not json_data :
217286 return {}
@@ -233,9 +302,16 @@ def calc_max_displacement(idx: int) -> list:
233302 This function extracts the real component of the atom's eigendisplacement,
234303 scales it by the specified magnitude, and returns the resulting vector.
235304 """
305+
306+ # get the atom index
307+ assert total_repeat_cell_cnt != 0
308+ modified_idx = (
309+ (idx // total_repeat_cell_cnt ) if total_repeat_cell_cnt else idx
310+ )
311+
236312 return [
237313 round (complex (vec ).real * magnitude , precision )
238- for vec in ph_bs .eigendisplacements [band ][qpoint ][idx ]
314+ for vec in ph_bs .eigendisplacements [band ][qpoint ][modified_idx ]
239315 ]
240316
241317 def calc_animation_step (max_displacement : list , coef : int ) -> list :
@@ -717,7 +793,27 @@ def get_figure(
717793 clickmode = "event+select" ,
718794 )
719795
720- figure = {"data" : bs_traces + dos_traces , "layout" : layout }
796+ default_red_dot = [
797+ {
798+ "type" : "scatter" ,
799+ "mode" : "markers" ,
800+ "x" : [0 ],
801+ "y" : [0 ],
802+ "marker" : {
803+ "color" : MARKER_COLOR ,
804+ "size" : MARKER_SIZE ,
805+ "symbol" : MARKER_SHAPE ,
806+ },
807+ "name" : "click-marker" ,
808+ "showlegend" : False ,
809+ "customdata" : [[0 , 0 ]],
810+ "hovertemplate" : (
811+ "band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
812+ ),
813+ }
814+ ]
815+
816+ figure = {"data" : bs_traces + dos_traces + default_red_dot , "layout" : layout }
721817
722818 legend = dict (
723819 x = 1.02 ,
@@ -743,14 +839,46 @@ def generate_callbacks(self, app, cache) -> None:
743839 Output (self .id ("table" ), "children" ),
744840 Input (self .id ("ph_bs" ), "data" ),
745841 Input (self .id ("ph_dos" ), "data" ),
842+ Input (self .id ("ph-bsdos-graph" ), "clickData" ),
746843 )
747- def update_graph (bs , dos ):
844+ def update_graph (bs , dos , nclick ):
748845 if isinstance (bs , dict ):
749846 bs = PhononBandStructureSymmLine .from_dict (bs )
750847 if isinstance (dos , dict ):
751848 dos = CompletePhononDos .from_dict (dos )
752849
753850 figure = self .get_figure (bs , dos )
851+ if nclick and nclick .get ("points" ):
852+ # remove marker if there is one
853+ figure ["data" ] = [
854+ t for t in figure ["data" ] if t .get ("name" ) != "click-marker"
855+ ]
856+
857+ x_click = nclick ["points" ][0 ]["x" ]
858+ y_click = nclick ["points" ][0 ]["y" ]
859+
860+ pt = nclick ["points" ][0 ]
861+ qpoint , band_num = pt .get ("customdata" , [0 , 0 ])
862+
863+ figure ["data" ].append (
864+ {
865+ "type" : "scatter" ,
866+ "mode" : "markers" ,
867+ "x" : [x_click ],
868+ "y" : [y_click ],
869+ "marker" : {
870+ "color" : MARKER_COLOR ,
871+ "size" : MARKER_SIZE ,
872+ "symbol" : MARKER_SHAPE ,
873+ },
874+ "name" : "click-marker" ,
875+ "showlegend" : False ,
876+ "customdata" : [[qpoint , band_num ]],
877+ "hovertemplate" : (
878+ "band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
879+ ),
880+ }
881+ )
754882
755883 zone_scene = self .get_brillouin_zone_scene (bs )
756884
@@ -775,18 +903,37 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
775903 Output (self .id ("crystal-animation" ), "data" ),
776904 Input (self .id ("ph-bsdos-graph" ), "clickData" ),
777905 Input (self .id ("ph_bs" ), "data" ),
906+ Input (self .id ("controls-btn" ), "n_clicks" ),
907+ Input (self .get_all_kwargs_id (), "value" ),
778908 # prevent_initial_call=True
779909 )
780- def update_crystal_animation (cd , bs ):
910+ def update_crystal_animation (cd , bs , update , kwargs ):
781911 if not bs :
782912 raise PreventUpdate
783913
784914 if isinstance (bs , dict ):
785915 bs = PhononBandStructureSymmLine .from_dict (bs )
786916
787- struc_graph = StructureGraph .from_local_env_strategy (
788- bs .structure , CrystalNN ()
917+ kwargs = self .reconstruct_kwargs_from_state ()
918+
919+ # animation control
920+ scale_x , scale_y , scale_z = (
921+ int (kwargs ["scale-x" ]),
922+ int (kwargs ["scale-y" ]),
923+ int (kwargs ["scale-z" ]),
924+ )
925+ magnitude_fraction = kwargs ["magnitude" ]
926+ magnitude = (
927+ MAX_MAGNITUDE - MIN_MAGNITUDE
928+ ) * magnitude_fraction + MIN_MAGNITUDE
929+
930+ # create supercell
931+ trans = SupercellTransformation (
932+ ((scale_x , 0 , 0 ), (0 , scale_y , 0 ), (0 , 0 , scale_z ))
789933 )
934+ struct = trans .apply_transformation (bs .structure )
935+
936+ struc_graph = StructureGraph .from_local_env_strategy (struct , CrystalNN ())
790937 scene = struc_graph .get_scene (
791938 draw_image_atoms = False ,
792939 bonded_sites_outside_unit_cell = False ,
@@ -806,6 +953,8 @@ def update_crystal_animation(cd, bs):
806953 json_data = json_data ,
807954 band = band_num ,
808955 qpoint = qpoint ,
956+ total_repeat_cell_cnt = scale_x * scale_y * scale_z ,
957+ magnitude = magnitude ,
809958 )
810959
811960
0 commit comments