Skip to content

Commit c706707

Browse files
Chiu PeterChiu Peter
authored andcommitted
add supercell construction
1 parent 19573d0 commit c706707

File tree

1 file changed

+170
-21
lines changed

1 file changed

+170
-21
lines changed

crystal_toolkit/components/phonon.py

Lines changed: 170 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
1919
from pymatgen.phonon.dos import CompletePhononDos
2020
from pymatgen.phonon.plotter import PhononBSPlotter
21+
from pymatgen.transformations.standard_transformations import SupercellTransformation
2122

2223
from crystal_toolkit.core.mpcomponent import MPComponent
2324
from crystal_toolkit.core.panelcomponent import PanelComponent
@@ -30,6 +31,11 @@
3031
from pymatgen.electronic_structure.dos import CompleteDos
3132

3233
DISPLACE_COEF = [0, 1, 0, -1, 0]
34+
MARKER_COLOR = "red"
35+
MARKER_SIZE = 12
36+
MARKER_SHAPE = "x"
37+
MAX_MAGNITUDE = 400
38+
MIN_MAGNITUDE = 0
3339

3440
# TODOs:
3541
# - look for additional projection methods in phonon DOS (currently only atom
@@ -149,22 +155,70 @@ def _sub_layouts(self) -> dict[str, Component]:
149155

150156
# crystal visualization
151157

152-
tip = html.P(
153-
"Click different q-points and bands in the dispersion diagram to see the crystal vibration.",
154-
id=self.id("crystal-tip"),
155-
style={
156-
"margin": "0 0 12px",
157-
"fontSize": "16px",
158-
"color": "#555",
159-
"textAlign": "center",
160-
},
158+
tip = html.H5(
159+
"💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!",
160+
)
161+
162+
crystal_animation = html.Div(
163+
CrystalToolkitAnimationScene(
164+
data={},
165+
sceneSize="500px",
166+
id=self.id("crystal-animation"),
167+
settings={"defaultZoom": 1.2},
168+
),
169+
style={"width": "60%"},
161170
)
162171

163-
crystal_animation = CrystalToolkitAnimationScene(
164-
data={},
165-
sceneSize="200px",
166-
id=self.id("crystal-animation"),
167-
settings={"defaultZoom": 1.5},
172+
crystal_animation_controls = html.Div(
173+
[
174+
html.Br(),
175+
html.Div(tip, style={"textAlign": "center"}),
176+
html.Br(),
177+
html.H5("Control Panel", style={"textAlign": "center"}),
178+
html.H6("Supercell modification"),
179+
html.Br(),
180+
html.Div(
181+
[
182+
self.get_numerical_input(
183+
kwarg_label="scale-x",
184+
default=1,
185+
is_int=True,
186+
label="x",
187+
style={"width": "5rem"},
188+
),
189+
self.get_numerical_input(
190+
kwarg_label="scale-y",
191+
default=1,
192+
is_int=True,
193+
label="y",
194+
style={"width": "5rem"},
195+
),
196+
self.get_numerical_input(
197+
kwarg_label="scale-z",
198+
default=1,
199+
is_int=True,
200+
label="z",
201+
style={"width": "5rem"},
202+
),
203+
html.Button(
204+
"Update",
205+
id=self.id("controls-btn"),
206+
style={"height": "40px"},
207+
),
208+
],
209+
style={"display": "flex"},
210+
),
211+
html.Br(),
212+
html.Div(
213+
self.get_slider_input(
214+
kwarg_label="magnitude",
215+
default=0.5,
216+
step=0.01,
217+
domain=[0, 1],
218+
label="Vibration magnitude",
219+
)
220+
),
221+
],
168222
)
169223

170224
return {
@@ -176,12 +230,26 @@ def _sub_layouts(self) -> dict[str, Component]:
176230
"table": summary_table,
177231
"crystal-animation": crystal_animation,
178232
"tip": tip,
233+
"crystal-animation-controls": crystal_animation_controls,
179234
}
180235

181236
def layout(self) -> html.Div:
182237
sub_layouts = self._sub_layouts
183238
crystal_animation = Columns(
184-
[Column([sub_layouts["tip"], sub_layouts["crystal-animation"]])]
239+
[
240+
Column(
241+
[
242+
# sub_layouts["tip"],
243+
Columns(
244+
[
245+
sub_layouts["crystal-animation"],
246+
sub_layouts["crystal-animation-controls"],
247+
]
248+
)
249+
]
250+
),
251+
# Column([sub_layouts["crystal-animation-controls"]])
252+
]
185253
)
186254
graph = Columns([Column([sub_layouts["graph"]])])
187255
controls = Columns(
@@ -212,6 +280,7 @@ def _get_eigendisplacement(
212280
qpoint: int = 0,
213281
precision: int = 15,
214282
magnitude: int = 225,
283+
total_repeat_cell_cnt: int | None = None,
215284
) -> dict:
216285
if not ph_bs or not json_data:
217286
return {}
@@ -233,9 +302,16 @@ def calc_max_displacement(idx: int) -> list:
233302
This function extracts the real component of the atom's eigendisplacement,
234303
scales it by the specified magnitude, and returns the resulting vector.
235304
"""
305+
306+
# get the atom index
307+
assert total_repeat_cell_cnt != 0
308+
modified_idx = (
309+
(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx
310+
)
311+
236312
return [
237313
round(complex(vec).real * magnitude, precision)
238-
for vec in ph_bs.eigendisplacements[band][qpoint][idx]
314+
for vec in ph_bs.eigendisplacements[band][qpoint][modified_idx]
239315
]
240316

241317
def calc_animation_step(max_displacement: list, coef: int) -> list:
@@ -717,7 +793,27 @@ def get_figure(
717793
clickmode="event+select",
718794
)
719795

720-
figure = {"data": bs_traces + dos_traces, "layout": layout}
796+
default_red_dot = [
797+
{
798+
"type": "scatter",
799+
"mode": "markers",
800+
"x": [0],
801+
"y": [0],
802+
"marker": {
803+
"color": MARKER_COLOR,
804+
"size": MARKER_SIZE,
805+
"symbol": MARKER_SHAPE,
806+
},
807+
"name": "click-marker",
808+
"showlegend": False,
809+
"customdata": [[0, 0]],
810+
"hovertemplate": (
811+
"band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
812+
),
813+
}
814+
]
815+
816+
figure = {"data": bs_traces + dos_traces + default_red_dot, "layout": layout}
721817

722818
legend = dict(
723819
x=1.02,
@@ -743,14 +839,46 @@ def generate_callbacks(self, app, cache) -> None:
743839
Output(self.id("table"), "children"),
744840
Input(self.id("ph_bs"), "data"),
745841
Input(self.id("ph_dos"), "data"),
842+
Input(self.id("ph-bsdos-graph"), "clickData"),
746843
)
747-
def update_graph(bs, dos):
844+
def update_graph(bs, dos, nclick):
748845
if isinstance(bs, dict):
749846
bs = PhononBandStructureSymmLine.from_dict(bs)
750847
if isinstance(dos, dict):
751848
dos = CompletePhononDos.from_dict(dos)
752849

753850
figure = self.get_figure(bs, dos)
851+
if nclick and nclick.get("points"):
852+
# remove marker if there is one
853+
figure["data"] = [
854+
t for t in figure["data"] if t.get("name") != "click-marker"
855+
]
856+
857+
x_click = nclick["points"][0]["x"]
858+
y_click = nclick["points"][0]["y"]
859+
860+
pt = nclick["points"][0]
861+
qpoint, band_num = pt.get("customdata", [0, 0])
862+
863+
figure["data"].append(
864+
{
865+
"type": "scatter",
866+
"mode": "markers",
867+
"x": [x_click],
868+
"y": [y_click],
869+
"marker": {
870+
"color": MARKER_COLOR,
871+
"size": MARKER_SIZE,
872+
"symbol": MARKER_SHAPE,
873+
},
874+
"name": "click-marker",
875+
"showlegend": False,
876+
"customdata": [[qpoint, band_num]],
877+
"hovertemplate": (
878+
"band: %{customdata[1]}<br>q-point: %{customdata[0]}<br>"
879+
),
880+
}
881+
)
754882

755883
zone_scene = self.get_brillouin_zone_scene(bs)
756884

@@ -775,18 +903,37 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
775903
Output(self.id("crystal-animation"), "data"),
776904
Input(self.id("ph-bsdos-graph"), "clickData"),
777905
Input(self.id("ph_bs"), "data"),
906+
Input(self.id("controls-btn"), "n_clicks"),
907+
Input(self.get_all_kwargs_id(), "value"),
778908
# prevent_initial_call=True
779909
)
780-
def update_crystal_animation(cd, bs):
910+
def update_crystal_animation(cd, bs, update, kwargs):
781911
if not bs:
782912
raise PreventUpdate
783913

784914
if isinstance(bs, dict):
785915
bs = PhononBandStructureSymmLine.from_dict(bs)
786916

787-
struc_graph = StructureGraph.from_local_env_strategy(
788-
bs.structure, CrystalNN()
917+
kwargs = self.reconstruct_kwargs_from_state()
918+
919+
# animation control
920+
scale_x, scale_y, scale_z = (
921+
int(kwargs["scale-x"]),
922+
int(kwargs["scale-y"]),
923+
int(kwargs["scale-z"]),
924+
)
925+
magnitude_fraction = kwargs["magnitude"]
926+
magnitude = (
927+
MAX_MAGNITUDE - MIN_MAGNITUDE
928+
) * magnitude_fraction + MIN_MAGNITUDE
929+
930+
# create supercell
931+
trans = SupercellTransformation(
932+
((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z))
789933
)
934+
struct = trans.apply_transformation(bs.structure)
935+
936+
struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN())
790937
scene = struc_graph.get_scene(
791938
draw_image_atoms=False,
792939
bonded_sites_outside_unit_cell=False,
@@ -806,6 +953,8 @@ def update_crystal_animation(cd, bs):
806953
json_data=json_data,
807954
band=band_num,
808955
qpoint=qpoint,
956+
total_repeat_cell_cnt=scale_x * scale_y * scale_z,
957+
magnitude=magnitude,
809958
)
810959

811960

0 commit comments

Comments
 (0)