Skip to content

Commit 0aa6e25

Browse files
authored
Merge branch 'main' into phonon_v2
2 parents b31a1f6 + bd75b78 commit 0aa6e25

File tree

9 files changed

+448
-435
lines changed

9 files changed

+448
-435
lines changed

.github/workflows/upgrade-dependencies.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ jobs:
2323
- name: Upgrade Python dependencies
2424
shell: bash
2525
run: |
26-
python${{ matrix.python-version }} -m pip install --upgrade pip pip-tools wheel cython setuptools
26+
python${{ matrix.python-version }} -m pip install --upgrade "pip<25.3" pip-tools wheel cython setuptools
2727
python${{ matrix.python-version }} -m pip install `grep numpy== requirements/${{ matrix.os }}_py${{ matrix.python-version }}_extras.txt`
28-
python${{ matrix.python-version }} -m piptools compile -q --upgrade --resolver=backtracking -o requirements/${{ matrix.os }}_py${{ matrix.python-version }}.txt pyproject.toml
29-
python${{ matrix.python-version }} -m piptools compile -q --upgrade --resolver=backtracking --all-extras -o requirements/${{ matrix.os }}_py${{ matrix.python-version }}_extras.txt pyproject.toml
28+
python${{ matrix.python-version }} -m piptools compile -q --upgrade -o requirements/${{ matrix.os }}_py${{ matrix.python-version }}.txt pyproject.toml
29+
python${{ matrix.python-version }} -m piptools compile -q --upgrade --all-extras -o requirements/${{ matrix.os }}_py${{ matrix.python-version }}_extras.txt pyproject.toml
3030
- name: Detect changes
3131
id: changes
3232
shell: bash

.gitignore

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,3 @@ build/*
1313
*.egg-info
1414
dist/*
1515
.vscode
16-
17-
.venv/
18-
_version.py

crystal_toolkit/components/phonon.py

Lines changed: 119 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
from __future__ import annotations
22

33
import itertools
4-
from copy import deepcopy
54
from typing import TYPE_CHECKING, Any
65

76
import numpy as np
87
import plotly.graph_objects as go
98
from dash import dcc, html
109
from dash.dependencies import Component, Input, Output, State
1110
from dash.exceptions import PreventUpdate
12-
from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene
13-
14-
# crystal animation algo
15-
from pymatgen.analysis.graphs import StructureGraph
16-
from pymatgen.analysis.local_env import CrystalNN
11+
from dash_mp_components import CrystalToolkitScene
1712
from pymatgen.ext.matproj import MPRester
1813
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
1914
from pymatgen.phonon.dos import CompletePhononDos
@@ -23,7 +18,14 @@
2318
from crystal_toolkit.core.mpcomponent import MPComponent
2419
from crystal_toolkit.core.panelcomponent import PanelComponent
2520
from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres
26-
from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list
21+
from crystal_toolkit.helpers.layouts import (
22+
Column,
23+
Columns,
24+
Label,
25+
MessageBody,
26+
MessageContainer,
27+
get_data_list,
28+
)
2729
from crystal_toolkit.helpers.pretty_labels import pretty_labels
2830

2931
if TYPE_CHECKING:
@@ -66,32 +68,26 @@ def __init__(
6668
**kwargs,
6769
)
6870

69-
bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos(
70-
self.initial_data["default"]
71-
)
72-
self.create_store("bs-store", bs)
73-
self.create_store("bs", None)
74-
self.create_store("dos", None)
75-
7671
@property
7772
def _sub_layouts(self) -> dict[str, Component]:
7873
# defaults
7974
state = {"label-select": "sc", "dos-select": "ap"}
8075

81-
fig = PhononBandstructureAndDosComponent.get_figure(None, None)
76+
bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos(
77+
self.initial_data["default"]
78+
)
79+
fig = PhononBandstructureAndDosComponent.get_figure(bs, dos)
8280
# Main plot
8381
graph = dcc.Graph(
8482
figure=fig,
8583
config={"displayModeBar": False},
86-
responsive=False,
84+
responsive=True,
8785
id=self.id("ph-bsdos-graph"),
8886
)
8987

9088
# Brillouin zone
91-
zone_scene = self.get_brillouin_zone_scene(None)
92-
zone = CrystalToolkitScene(
93-
data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone")
94-
)
89+
zone_scene = self.get_brillouin_zone_scene(bs)
90+
zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px")
9591

9692
# Hide by default if not loaded by mpid, switching between k-paths
9793
# on-the-fly only supported for bandstructures retrieved from MP
@@ -113,11 +109,9 @@ def _sub_layouts(self) -> dict[str, Component]:
113109
options=options,
114110
)
115111
],
116-
style=(
117-
{"width": "200px"}
118-
if show_path_options
119-
else {"maxWidth": "200", "display": "none"}
120-
),
112+
style={"width": "200px"}
113+
if show_path_options
114+
else {"maxWidth": "200", "display": "none"},
121115
id=self.id("path-container"),
122116
)
123117

@@ -132,11 +126,9 @@ def _sub_layouts(self) -> dict[str, Component]:
132126
options=options,
133127
)
134128
],
135-
style=(
136-
{"width": "200px"}
137-
if show_path_options
138-
else {"width": "200px", "display": "none"}
139-
),
129+
style={"width": "200px"}
130+
if show_path_options
131+
else {"width": "200px", "display": "none"},
140132
id=self.id("label-container"),
141133
)
142134

@@ -150,7 +142,7 @@ def _sub_layouts(self) -> dict[str, Component]:
150142
style={"width": "200px"},
151143
)
152144

153-
summary_dict = self._get_data_list_dict(None, None)
145+
summary_dict = self._get_data_list_dict(bs, dos)
154146
summary_table = get_data_list(summary_dict)
155147

156148
# crystal visualization
@@ -272,7 +264,7 @@ def layout(self) -> html.Div:
272264
)
273265
brillouin_zone = Columns(
274266
[
275-
Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")),
267+
Column([Label("Summary"), sub_layouts["table"]]),
276268
Column([Label("Brillouin Zone"), sub_layouts["zone"]]),
277269
]
278270
)
@@ -541,7 +533,6 @@ def get_ph_bandstructure_traces(bs, freq_range):
541533
"line": {"color": "#1f77b4"},
542534
"hoverinfo": "skip",
543535
"name": "Total",
544-
"customdata": [[di, band_num] for di in range(len(x_dat))],
545536
"hovertemplate": "%{y:.2f} THz",
546537
"showlegend": False,
547538
"xaxis": "x",
@@ -587,9 +578,6 @@ def get_ph_bandstructure_traces(bs, freq_range):
587578
def _get_data_list_dict(
588579
bs: PhononBandStructureSymmLine, dos: CompletePhononDos
589580
) -> dict[str, str | bool | int]:
590-
if (not bs) and (not dos):
591-
return {}
592-
593581
bs_minpoint, bs_min_freq = bs.min_freq()
594582
min_freq_report = (
595583
f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}"
@@ -615,7 +603,7 @@ def _get_data_list_dict(
615603
target="blank",
616604
),
617605
]
618-
): ("Yes" if bs.has_nac else "No"),
606+
): "Yes" if bs.has_nac else "No",
619607
"Has imaginary frequencies": "Yes" if bs.has_imaginary_freq() else "No",
620608
"Has eigen-displacements": "Yes" if bs.has_eigendisplacements else "No",
621609
"Min frequency": min_freq_report,
@@ -685,9 +673,14 @@ def get_figure(
685673
ph_dos: CompletePhononDos | None = None,
686674
freq_range: tuple[float | None, float | None] = (None, None),
687675
) -> go.Figure:
676+
if freq_range[0] is None:
677+
freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1])
678+
679+
if freq_range[1] is None:
680+
freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05)
681+
688682
if (not ph_dos) and (not ph_bs):
689683
empty_plot_style = {
690-
"height": 500,
691684
"xaxis": {"visible": False},
692685
"yaxis": {"visible": False},
693686
"paper_bgcolor": "rgba(0,0,0,0)",
@@ -696,12 +689,6 @@ def get_figure(
696689

697690
return go.Figure(layout=empty_plot_style)
698691

699-
if freq_range[0] is None:
700-
freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1])
701-
702-
if freq_range[1] is None:
703-
freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05)
704-
705692
if ph_bs:
706693
(
707694
bs_traces,
@@ -798,7 +785,7 @@ def get_figure(
798785
paper_bgcolor="rgba(0,0,0,0)",
799786
plot_bgcolor="rgba(230,230,230,230)",
800787
margin=dict(l=60, b=50, t=50, pad=0, r=30),
801-
clickmode="event+select",
788+
# clickmode="event+select"
802789
)
803790

804791
figure = {"data": bs_traces + dos_traces, "layout": layout}
@@ -836,6 +823,9 @@ def update_graph(bs, dos, nclick):
836823
dos = CompletePhononDos.from_dict(dos)
837824

838825
figure = self.get_figure(bs, dos)
826+
return dcc.Graph(
827+
figure=figure, config={"displayModeBar": False}, responsive=True
828+
)
839829

840830
# remove marker if there is one
841831
figure["data"] = [
@@ -870,10 +860,91 @@ def update_graph(bs, dos, nclick):
870860

871861
zone_scene = self.get_brillouin_zone_scene(bs)
872862

873-
summary_dict = self._get_data_list_dict(bs, dos)
874-
summary_table = get_data_list(summary_dict)
863+
return label_value, label_style
864+
865+
@app.callback(
866+
Output(self.id("dos-select"), "options"),
867+
Output(self.id("path-convention"), "options"),
868+
Output(self.id("path-container"), "style"),
869+
Input(self.id("elements"), "data"),
870+
Input(self.id("mpid"), "data"),
871+
)
872+
def update_select(elements, mpid):
873+
if elements is None:
874+
raise PreventUpdate
875+
if not mpid:
876+
dos_options = (
877+
[{"label": "Element Projected", "value": "ap"}]
878+
+ [{"label": "Orbital Projected - Total", "value": "op"}]
879+
+ [
880+
{
881+
"label": "Orbital Projected - " + str(ele_label),
882+
"value": "orb" + str(ele_label),
883+
}
884+
for ele_label in elements
885+
]
886+
)
887+
888+
path_options = [{"label": "N/A", "value": "sc"}]
889+
path_style = {"maxWidth": "200", "display": "none"}
890+
891+
return dos_options, path_options, path_style
892+
dos_options = (
893+
[{"label": "Element Projected", "value": "ap"}]
894+
+ [{"label": "Orbital Projected - Total", "value": "op"}]
895+
+ [
896+
{
897+
"label": "Orbital Projected - " + str(ele_label),
898+
"value": "orb" + str(ele_label),
899+
}
900+
for ele_label in elements
901+
]
902+
)
903+
904+
path_options = [
905+
{"label": "Setyawan-Curtarolo", "value": "sc"},
906+
{"label": "Latimer-Munro", "value": "lm"},
907+
{"label": "Hinuma et al.", "value": "hin"},
908+
]
909+
910+
path_style = {"maxWidth": "200"}
911+
912+
return dos_options, path_options, path_style
913+
914+
@app.callback(
915+
Output(self.id("traces"), "data"),
916+
Output(self.id("elements"), "data"),
917+
Input(self.id(), "data"),
918+
Input(self.id("path-convention"), "value"),
919+
Input(self.id("dos-select"), "value"),
920+
Input(self.id("label-select"), "value"),
921+
)
922+
def bs_dos_data(data, dos_select, label_select):
923+
# Obtain bands to plot over and generate traces for bs data:
924+
energy_window = (-6.0, 10.0)
925+
926+
traces = []
927+
928+
bsml, density_of_states = self._get_ph_bs_dos(data)
929+
930+
if self.bandstructure_symm_line:
931+
bs_traces = self.get_ph_bandstructure_traces(
932+
bsml, freq_range=energy_window
933+
)
934+
traces.append(bs_traces)
935+
936+
if self.density_of_states:
937+
dos_traces = self.get_ph_dos_traces(
938+
density_of_states, freq_range=energy_window
939+
)
940+
traces.append(dos_traces)
941+
942+
# traces = [bs_traces, dos_traces, bs_data]
943+
944+
# TODO: not tested if this is correct way to get element list
945+
elements = list(map(str, density_of_states.get_element_dos()))
875946

876-
return figure, zone_scene.to_json(), summary_table
947+
return traces, elements
877948

878949
@app.callback(
879950
Output(self.id("brillouin-zone"), "data"),

0 commit comments

Comments
 (0)