Skip to content

Commit 3331e3d

Browse files
Chiu PeterChiu Peter
authored andcommitted
pre rebase
1 parent c706707 commit 3331e3d

File tree

1 file changed

+83
-80
lines changed

1 file changed

+83
-80
lines changed

crystal_toolkit/components/phonon.py

Lines changed: 83 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import plotly.graph_objects as go
99
from dash import dcc, html
10-
from dash.dependencies import Component, Input, Output
10+
from dash.dependencies import Component, Input, Output, State
1111
from dash.exceptions import PreventUpdate
1212
from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene
1313

@@ -34,7 +34,7 @@
3434
MARKER_COLOR = "red"
3535
MARKER_SIZE = 12
3636
MARKER_SHAPE = "x"
37-
MAX_MAGNITUDE = 400
37+
MAX_MAGNITUDE = 300
3838
MIN_MAGNITUDE = 0
3939

4040
# TODOs:
@@ -165,6 +165,8 @@ def _sub_layouts(self) -> dict[str, Component]:
165165
sceneSize="500px",
166166
id=self.id("crystal-animation"),
167167
settings={"defaultZoom": 1.2},
168+
axisView="SW",
169+
showControls=False, # disable download for now
168170
),
169171
style={"width": "60%"},
170172
)
@@ -184,25 +186,28 @@ def _sub_layouts(self) -> dict[str, Component]:
184186
default=1,
185187
is_int=True,
186188
label="x",
189+
min=1,
187190
style={"width": "5rem"},
188191
),
189192
self.get_numerical_input(
190193
kwarg_label="scale-y",
191194
default=1,
192195
is_int=True,
193196
label="y",
197+
min=1,
194198
style={"width": "5rem"},
195199
),
196200
self.get_numerical_input(
197201
kwarg_label="scale-z",
198202
default=1,
199203
is_int=True,
200204
label="z",
205+
min=1,
201206
style={"width": "5rem"},
202207
),
203208
html.Button(
204209
"Update",
205-
id=self.id("controls-btn"),
210+
id=self.id("supercell-controls-btn"),
206211
style={"height": "40px"},
207212
),
208213
],
@@ -233,13 +238,12 @@ def _sub_layouts(self) -> dict[str, Component]:
233238
"crystal-animation-controls": crystal_animation_controls,
234239
}
235240

236-
def layout(self) -> html.Div:
241+
def _get_animation_panel(self):
237242
sub_layouts = self._sub_layouts
238-
crystal_animation = Columns(
243+
return Columns(
239244
[
240245
Column(
241246
[
242-
# sub_layouts["tip"],
243247
Columns(
244248
[
245249
sub_layouts["crystal-animation"],
@@ -248,9 +252,12 @@ def layout(self) -> html.Div:
248252
)
249253
]
250254
),
251-
# Column([sub_layouts["crystal-animation-controls"]])
252255
]
253256
)
257+
258+
def layout(self) -> html.Div:
259+
sub_layouts = self._sub_layouts
260+
crystal_animation = self._get_animation_panel()
254261
graph = Columns([Column([sub_layouts["graph"]])])
255262
controls = Columns(
256263
[
@@ -279,8 +286,8 @@ def _get_eigendisplacement(
279286
band: int = 0,
280287
qpoint: int = 0,
281288
precision: int = 15,
282-
magnitude: int = 225,
283-
total_repeat_cell_cnt: int | None = None,
289+
magnitude: int = MAX_MAGNITUDE / 2,
290+
total_repeat_cell_cnt: int = 1,
284291
) -> dict:
285292
if not ph_bs or not json_data:
286293
return {}
@@ -305,8 +312,9 @@ def calc_max_displacement(idx: int) -> list:
305312

306313
# get the atom index
307314
assert total_repeat_cell_cnt != 0
315+
308316
modified_idx = (
309-
(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx
317+
int(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx
310318
)
311319

312320
return [
@@ -793,27 +801,7 @@ def get_figure(
793801
clickmode="event+select",
794802
)
795803

796-
default_red_dot = [
797-
{
798-
"type": "scatter",
799-
"mode": "markers",
800-
"x": [0],
801-
"y": [0],
802-
"marker": {
803-
"color": MARKER_COLOR,
804-
"size": MARKER_SIZE,
805-
"symbol": MARKER_SHAPE,
806-
},
807-
"name": "click-marker",
808-
"showlegend": False,
809-
"customdata": [[0, 0]],
810-
"hovertemplate": (
811-
"band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
812-
),
813-
}
814-
]
815-
816-
figure = {"data": bs_traces + dos_traces + default_red_dot, "layout": layout}
804+
figure = {"data": bs_traces + dos_traces, "layout": layout}
817805

818806
legend = dict(
819807
x=1.02,
@@ -848,37 +836,37 @@ def update_graph(bs, dos, nclick):
848836
dos = CompletePhononDos.from_dict(dos)
849837

850838
figure = self.get_figure(bs, dos)
851-
if nclick and nclick.get("points"):
852-
# remove marker if there is one
853-
figure["data"] = [
854-
t for t in figure["data"] if t.get("name") != "click-marker"
855-
]
856839

857-
x_click = nclick["points"][0]["x"]
858-
y_click = nclick["points"][0]["y"]
840+
# remove marker if there is one
841+
figure["data"] = [
842+
t for t in figure["data"] if t.get("name") != "click-marker"
843+
]
859844

860-
pt = nclick["points"][0]
861-
qpoint, band_num = pt.get("customdata", [0, 0])
845+
x_click = nclick["points"][0]["x"] if nclick else 0
846+
y_click = nclick["points"][0]["y"] if nclick else 0
847+
pt = nclick["points"][0] if nclick else {}
862848

863-
figure["data"].append(
864-
{
865-
"type": "scatter",
866-
"mode": "markers",
867-
"x": [x_click],
868-
"y": [y_click],
869-
"marker": {
870-
"color": MARKER_COLOR,
871-
"size": MARKER_SIZE,
872-
"symbol": MARKER_SHAPE,
873-
},
874-
"name": "click-marker",
875-
"showlegend": False,
876-
"customdata": [[qpoint, band_num]],
877-
"hovertemplate": (
878-
"band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
879-
),
880-
}
881-
)
849+
qpoint, band_num = pt.get("customdata", [0, 0])
850+
851+
figure["data"].append(
852+
{
853+
"type": "scatter",
854+
"mode": "markers",
855+
"x": [x_click],
856+
"y": [y_click],
857+
"marker": {
858+
"color": MARKER_COLOR,
859+
"size": MARKER_SIZE,
860+
"symbol": MARKER_SHAPE,
861+
},
862+
"name": "click-marker",
863+
"showlegend": False,
864+
"customdata": [[qpoint, band_num]],
865+
"hovertemplate": (
866+
"band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
867+
),
868+
}
869+
)
882870

883871
zone_scene = self.get_brillouin_zone_scene(bs)
884872

@@ -903,35 +891,45 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
903891
Output(self.id("crystal-animation"), "data"),
904892
Input(self.id("ph-bsdos-graph"), "clickData"),
905893
Input(self.id("ph_bs"), "data"),
906-
Input(self.id("controls-btn"), "n_clicks"),
907-
Input(self.get_all_kwargs_id(), "value"),
894+
Input(self.id("supercell-controls-btn"), "n_clicks"),
895+
Input(self.get_kwarg_id("magnitude"), "value"),
896+
State(self.get_kwarg_id("scale-x"), "value"),
897+
State(self.get_kwarg_id("scale-y"), "value"),
898+
State(self.get_kwarg_id("scale-z"), "value"),
908899
# prevent_initial_call=True
909900
)
910-
def update_crystal_animation(cd, bs, update, kwargs):
901+
def update_crystal_animation(
902+
cd, bs, sueprcell_update, magnitude_fraction, scale_x, scale_y, scale_z
903+
):
904+
# Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields,
905+
# ensuring updates occur only after the `supercell-controls-btn`` is clicked.
906+
911907
if not bs:
912908
raise PreventUpdate
913909

910+
# Since `self.get_kwarg_id()` uses dash.dependencies.ALL, it returns a list of values.
911+
# Although we could use `magnitude_fraction = magnitude_fraction[0]` to get the first value,
912+
# this approach provides better clarity and readability.
913+
kwargs = self.reconstruct_kwargs_from_state()
914+
magnitude_fraction = kwargs.get("magnitude")
915+
scale_x = kwargs.get("scale-x")
916+
scale_y = kwargs.get("scale-y")
917+
scale_z = kwargs.get("scale-z")
918+
914919
if isinstance(bs, dict):
915920
bs = PhononBandStructureSymmLine.from_dict(bs)
916921

917-
kwargs = self.reconstruct_kwargs_from_state()
922+
struct = bs.structure
923+
total_repeat_cell_cnt = 1
924+
# update structure if the controls got triggered
925+
if sueprcell_update:
926+
total_repeat_cell_cnt = scale_x * scale_y * scale_z
918927

919-
# animation control
920-
scale_x, scale_y, scale_z = (
921-
int(kwargs["scale-x"]),
922-
int(kwargs["scale-y"]),
923-
int(kwargs["scale-z"]),
924-
)
925-
magnitude_fraction = kwargs["magnitude"]
926-
magnitude = (
927-
MAX_MAGNITUDE - MIN_MAGNITUDE
928-
) * magnitude_fraction + MIN_MAGNITUDE
929-
930-
# create supercell
931-
trans = SupercellTransformation(
932-
((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z))
933-
)
934-
struct = trans.apply_transformation(bs.structure)
928+
# create supercell
929+
trans = SupercellTransformation(
930+
((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z))
931+
)
932+
struct = trans.apply_transformation(struct)
935933

936934
struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN())
937935
scene = struc_graph.get_scene(
@@ -948,12 +946,17 @@ def update_crystal_animation(cd, bs, update, kwargs):
948946
pt = cd["points"][0]
949947
qpoint, band_num = pt.get("customdata", [0, 0])
950948

949+
# magnitude
950+
magnitude = (
951+
MAX_MAGNITUDE - MIN_MAGNITUDE
952+
) * magnitude_fraction + MIN_MAGNITUDE
953+
951954
return PhononBandstructureAndDosComponent._get_eigendisplacement(
952955
ph_bs=bs,
953956
json_data=json_data,
954957
band=band_num,
955958
qpoint=qpoint,
956-
total_repeat_cell_cnt=scale_x * scale_y * scale_z,
959+
total_repeat_cell_cnt=total_repeat_cell_cnt,
957960
magnitude=magnitude,
958961
)
959962

0 commit comments

Comments
 (0)