Skip to content

Commit 2dd38d2

Browse files
committed
Updates to BS and XRD components
1 parent 30b27a3 commit 2dd38d2

File tree

2 files changed

+191
-172
lines changed

2 files changed

+191
-172
lines changed

crystal_toolkit/components/bandstructure.py

Lines changed: 96 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -138,24 +138,25 @@ def _sub_layouts(self):
138138
}
139139

140140
def layout(self):
141+
sub_layouts = self._sub_layouts
141142
return html.Div(
142143
[
143-
Columns([Column([self._sub_layouts["graph"]])]),
144+
Columns([Column([sub_layouts["graph"]])]),
144145
Columns(
145146
[
146147
Column(
147148
[
148-
self._sub_layouts["convention"],
149-
self._sub_layouts["label-select"],
150-
self._sub_layouts["dos-select"],
149+
sub_layouts["convention"],
150+
sub_layouts["label-select"],
151+
sub_layouts["dos-select"],
151152
]
152153
)
153154
]
154155
),
155156
Columns(
156157
[
157-
Column([Label("Summary"), self._sub_layouts["table"]]),
158-
Column([Label("Brillouin Zone"), self._sub_layouts["zone"]]),
158+
Column([Label("Summary"), sub_layouts["table"]]),
159+
Column([Label("Brillouin Zone"), sub_layouts["zone"]]),
159160
]
160161
),
161162
]
@@ -174,7 +175,7 @@ def _get_bs_dos(data):
174175
bandstructure_symm_line = data.get("bandstructure_symm_line")
175176
density_of_states = data.get("density_of_states")
176177

177-
if not mpid and (bandstructure_symm_line is None or density_of_states is None):
178+
if not mpid and bandstructure_symm_line is None and density_of_states is None:
178179
return None, None
179180

180181
if mpid:
@@ -615,99 +616,104 @@ def get_figure(
615616

616617
return go.Figure(layout=empty_plot_style)
617618

619+
# -- Add trace data to plots
620+
621+
traces = []
622+
xaxis_style = {}
623+
yaxis_style = {}
624+
xaxis_style_dos = {}
625+
yaxis_style_dos = {}
618626
if bs:
619627
bstraces, bs_data = BandstructureAndDosComponent.get_bandstructure_traces(
620628
bs, path_convention=path_convention, energy_window=energy_window
621629
)
630+
traces += bstraces
631+
632+
xaxis_style = dict(
633+
title=dict(text="Wave Vector", font=dict(size=16)),
634+
tickmode="array",
635+
tickvals=bs_data["ticks"]["distance"],
636+
ticktext=bs_data["ticks"]["label"],
637+
tickfont=dict(size=16),
638+
ticks="inside",
639+
tickwidth=2,
640+
showgrid=False,
641+
showline=True,
642+
zeroline=False,
643+
linewidth=2,
644+
mirror=True,
645+
range=[0, bs_data["ticks"]["distance"][-1]],
646+
linecolor="rgb(71,71,71)",
647+
gridcolor="white",
648+
)
649+
650+
yaxis_style = dict(
651+
title=dict(text="E−E<sub>fermi</sub> (eV)", font=dict(size=16)),
652+
tickfont=dict(size=16),
653+
showgrid=False,
654+
showline=True,
655+
zeroline=True,
656+
mirror="ticks",
657+
ticks="inside",
658+
linewidth=2,
659+
tickwidth=2,
660+
zerolinewidth=2,
661+
range=[-5, 9],
662+
linecolor="rgb(71,71,71)",
663+
gridcolor="white",
664+
zerolinecolor="white",
665+
)
622666

623667
if dos:
624668
dostraces = BandstructureAndDosComponent.get_dos_traces(
625669
dos, dos_select=dos_select, energy_window=energy_window
626670
)
671+
traces += dostraces
672+
673+
rmax = max(
674+
[
675+
max(dostraces[0]["x"]),
676+
abs(min(dostraces[0]["x"])),
677+
max(dostraces[1]["x"]),
678+
abs(min(dostraces[1]["x"])),
679+
]
680+
)
627681

628-
# TODO: add logic to handle if bstraces and/or dostraces not present
629-
630-
rmax = max(
631-
[
632-
max(dostraces[0]["x"]),
633-
abs(min(dostraces[0]["x"])),
634-
max(dostraces[1]["x"]),
635-
abs(min(dostraces[1]["x"])),
636-
]
637-
)
638-
639-
# -- Add trace data to plots
640-
641-
xaxis_style = dict(
642-
title=dict(text="Wave Vector", font=dict(size=16)),
643-
tickmode="array",
644-
tickvals=bs_data["ticks"]["distance"],
645-
ticktext=bs_data["ticks"]["label"],
646-
tickfont=dict(size=16),
647-
ticks="inside",
648-
tickwidth=2,
649-
showgrid=False,
650-
showline=True,
651-
zeroline=False,
652-
linewidth=2,
653-
mirror=True,
654-
range=[0, bs_data["ticks"]["distance"][-1]],
655-
linecolor="rgb(71,71,71)",
656-
gridcolor="white",
657-
)
658-
659-
yaxis_style = dict(
660-
title=dict(text="E−E<sub>fermi</sub> (eV)", font=dict(size=16)),
661-
tickfont=dict(size=16),
662-
showgrid=False,
663-
showline=True,
664-
zeroline=True,
665-
mirror="ticks",
666-
ticks="inside",
667-
linewidth=2,
668-
tickwidth=2,
669-
zerolinewidth=2,
670-
range=[-5, 9],
671-
linecolor="rgb(71,71,71)",
672-
gridcolor="white",
673-
zerolinecolor="white",
674-
)
675-
676-
xaxis_style_dos = dict(
677-
title=dict(text="Density of States", font=dict(size=16)),
678-
tickfont=dict(size=16),
679-
showgrid=False,
680-
showline=True,
681-
zeroline=False,
682-
mirror=True,
683-
ticks="inside",
684-
linewidth=2,
685-
tickwidth=2,
686-
range=[-rmax * 1.1 * int(len(bs_data["energy"].keys()) == 2), rmax * 1.1,],
687-
linecolor="rgb(71,71,71)",
688-
gridcolor="white",
689-
zerolinecolor="white",
690-
zerolinewidth=2,
691-
)
682+
xaxis_style_dos = dict(
683+
title=dict(text="Density of States", font=dict(size=16)),
684+
tickfont=dict(size=16),
685+
showgrid=False,
686+
showline=True,
687+
zeroline=False,
688+
mirror=True,
689+
ticks="inside",
690+
linewidth=2,
691+
tickwidth=2,
692+
range=[-rmax * 1.1 * int(len(dos.densities) == 2), rmax * 1.1,],
693+
linecolor="rgb(71,71,71)",
694+
gridcolor="white",
695+
zerolinecolor="white",
696+
zerolinewidth=2,
697+
)
692698

693-
yaxis_style_dos = dict(
694-
tickfont=dict(size=16),
695-
showgrid=False,
696-
showline=True,
697-
zeroline=True,
698-
showticklabels=False,
699-
mirror="ticks",
700-
ticks="inside",
701-
linewidth=2,
702-
tickwidth=2,
703-
zerolinewidth=2,
704-
range=[-5, 9],
705-
linecolor="rgb(71,71,71)",
706-
gridcolor="white",
707-
zerolinecolor="white",
708-
matches="y",
709-
anchor="x2",
710-
)
699+
yaxis_style_dos = dict(
700+
tickfont=dict(size=16),
701+
showgrid=False,
702+
showline=True,
703+
zeroline=True,
704+
showticklabels=False,
705+
mirror="ticks",
706+
ticks="inside",
707+
linewidth=2,
708+
tickwidth=2,
709+
zerolinewidth=2,
710+
range=[-5, 9],
711+
linecolor="rgb(71,71,71)",
712+
gridcolor="white",
713+
zerolinecolor="white",
714+
matches="y",
715+
anchor="x2",
716+
)
711717

712718
layout = dict(
713719
title="",
@@ -725,7 +731,7 @@ def get_figure(
725731
# clickmode="event+select"
726732
)
727733

728-
figure = {"data": bstraces + dostraces, "layout": layout}
734+
figure = {"data": traces, "layout": layout}
729735

730736
legend = dict(
731737
x=1.02,

0 commit comments

Comments
 (0)