Skip to content

Commit 616608f

Browse files
authored
Merge pull request #248 from gpetretto/devel
Updates to BS and XRD components
2 parents 7098073 + ce12a66 commit 616608f

File tree

2 files changed

+199
-177
lines changed

2 files changed

+199
-177
lines changed

crystal_toolkit/components/bandstructure.py

Lines changed: 103 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -158,24 +158,25 @@ def _sub_layouts(self):
158158
}
159159

160160
def layout(self):
161+
sub_layouts = self._sub_layouts
161162
return html.Div(
162163
[
163-
Columns([Column([self._sub_layouts["graph"]])]),
164+
Columns([Column([sub_layouts["graph"]])]),
164165
Columns(
165166
[
166167
Column(
167168
[
168-
self._sub_layouts["convention"],
169-
self._sub_layouts["label-select"],
170-
self._sub_layouts["dos-select"],
169+
sub_layouts["convention"],
170+
sub_layouts["label-select"],
171+
sub_layouts["dos-select"],
171172
]
172173
)
173174
]
174175
),
175176
Columns(
176177
[
177-
Column([Label("Summary"), self._sub_layouts["table"]]),
178-
Column([Label("Brillouin Zone"), self._sub_layouts["zone"]]),
178+
Column([Label("Summary"), sub_layouts["table"]]),
179+
Column([Label("Brillouin Zone"), sub_layouts["zone"]]),
179180
]
180181
),
181182
]
@@ -194,7 +195,7 @@ def _get_bs_dos(data):
194195
bandstructure_symm_line = data.get("bandstructure_symm_line")
195196
density_of_states = data.get("density_of_states")
196197

197-
if not mpid and (bandstructure_symm_line is None or density_of_states is None):
198+
if not mpid and bandstructure_symm_line is None and density_of_states is None:
198199
return None, None
199200

200201
if mpid:
@@ -560,11 +561,10 @@ def get_dos_traces(dos, dos_select, energy_window=(-6.0, 10.0)):
560561

561562
dostraces.append(trace_tdos)
562563

563-
ele_dos = dos.get_element_dos()
564-
[str(entry) for entry in ele_dos.keys()]
565-
566-
if dos_select == "ap":
567-
proj_data = ele_dos
564+
if dos_select == "tot":
565+
proj_data = {}
566+
elif dos_select == "ap":
567+
proj_data = dos.get_element_dos()
568568
elif dos_select == "op":
569569
proj_data = dos.get_spd_dos()
570570
elif "orb" in dos_select:
@@ -635,102 +635,111 @@ def get_figure(
635635

636636
return go.Figure(layout=empty_plot_style)
637637

638+
# -- Add trace data to plots
639+
640+
traces = []
641+
xaxis_style = {}
642+
yaxis_style = {}
643+
xaxis_style_dos = {}
644+
yaxis_style_dos = {}
638645
if bs:
639646
bstraces, bs_data = BandstructureAndDosComponent.get_bandstructure_traces(
640647
bs, path_convention=path_convention, energy_window=energy_window
641648
)
649+
traces += bstraces
650+
651+
xaxis_style = dict(
652+
title=dict(text="Wave Vector", font=dict(size=16)),
653+
tickmode="array",
654+
tickvals=bs_data["ticks"]["distance"],
655+
ticktext=bs_data["ticks"]["label"],
656+
tickfont=dict(size=16),
657+
ticks="inside",
658+
tickwidth=2,
659+
showgrid=False,
660+
showline=True,
661+
zeroline=False,
662+
linewidth=2,
663+
mirror=True,
664+
range=[0, bs_data["ticks"]["distance"][-1]],
665+
linecolor="rgb(71,71,71)",
666+
gridcolor="white",
667+
)
668+
669+
yaxis_style = dict(
670+
title=dict(text="E−E<sub>fermi</sub> (eV)", font=dict(size=16)),
671+
tickfont=dict(size=16),
672+
showgrid=False,
673+
showline=True,
674+
zeroline=True,
675+
mirror="ticks",
676+
ticks="inside",
677+
linewidth=2,
678+
tickwidth=2,
679+
zerolinewidth=2,
680+
range=[-5, 9],
681+
linecolor="rgb(71,71,71)",
682+
gridcolor="white",
683+
zerolinecolor="white",
684+
)
642685

643686
if dos:
644687
dostraces = BandstructureAndDosComponent.get_dos_traces(
645688
dos, dos_select=dos_select, energy_window=energy_window
646689
)
690+
traces += dostraces
647691

648-
# TODO: add logic to handle if bstraces and/or dostraces not present
649-
650-
rmax = max(
651-
[
692+
list_max = [
652693
max(dostraces[0]["x"]),
653694
abs(min(dostraces[0]["x"])),
654-
max(dostraces[1]["x"]),
655-
abs(min(dostraces[1]["x"])),
656695
]
657-
)
658-
659-
# -- Add trace data to plots
660-
661-
xaxis_style = dict(
662-
title=dict(text="Wave Vector", font=dict(size=16)),
663-
tickmode="array",
664-
tickvals=bs_data["ticks"]["distance"],
665-
ticktext=bs_data["ticks"]["label"],
666-
tickfont=dict(size=16),
667-
ticks="inside",
668-
tickwidth=2,
669-
showgrid=False,
670-
showline=True,
671-
zeroline=False,
672-
linewidth=2,
673-
mirror=True,
674-
range=[0, bs_data["ticks"]["distance"][-1]],
675-
linecolor="rgb(71,71,71)",
676-
gridcolor="white",
677-
)
678-
679-
yaxis_style = dict(
680-
title=dict(text="E−E<sub>fermi</sub> (eV)", font=dict(size=16)),
681-
tickfont=dict(size=16),
682-
showgrid=False,
683-
showline=True,
684-
zeroline=True,
685-
mirror="ticks",
686-
ticks="inside",
687-
linewidth=2,
688-
tickwidth=2,
689-
zerolinewidth=2,
690-
range=[-5, 9],
691-
linecolor="rgb(71,71,71)",
692-
gridcolor="white",
693-
zerolinecolor="white",
694-
)
695696

696-
xaxis_style_dos = dict(
697-
title=dict(text="Density of States", font=dict(size=16)),
698-
tickfont=dict(size=16),
699-
showgrid=False,
700-
showline=True,
701-
zeroline=False,
702-
mirror=True,
703-
ticks="inside",
704-
linewidth=2,
705-
tickwidth=2,
706-
range=[
707-
-rmax * 1.1 * int(len(bs_data["energy"].keys()) == 2),
708-
rmax * 1.1,
709-
],
710-
linecolor="rgb(71,71,71)",
711-
gridcolor="white",
712-
zerolinecolor="white",
713-
zerolinewidth=2,
714-
)
697+
# check the max of the second dos trace only if spin polarized
698+
spin_polarized = len(dos.densities.keys()) == 2
699+
if spin_polarized:
700+
list_max.extend(
701+
[
702+
max(dostraces[1]["x"]),
703+
abs(min(dostraces[1]["x"])),
704+
]
705+
)
706+
rmax = max(list_max)
707+
708+
xaxis_style_dos = dict(
709+
title=dict(text="Density of States", font=dict(size=16)),
710+
tickfont=dict(size=16),
711+
showgrid=False,
712+
showline=True,
713+
zeroline=False,
714+
mirror=True,
715+
ticks="inside",
716+
linewidth=2,
717+
tickwidth=2,
718+
range=[-rmax * 1.1 * int(len(dos.densities) == 2), rmax * 1.1,],
719+
linecolor="rgb(71,71,71)",
720+
gridcolor="white",
721+
zerolinecolor="white",
722+
zerolinewidth=2,
723+
)
715724

716-
yaxis_style_dos = dict(
717-
tickfont=dict(size=16),
718-
showgrid=False,
719-
showline=True,
720-
zeroline=True,
721-
showticklabels=False,
722-
mirror="ticks",
723-
ticks="inside",
724-
linewidth=2,
725-
tickwidth=2,
726-
zerolinewidth=2,
727-
range=[-5, 9],
728-
linecolor="rgb(71,71,71)",
729-
gridcolor="white",
730-
zerolinecolor="white",
731-
matches="y",
732-
anchor="x2",
733-
)
725+
yaxis_style_dos = dict(
726+
tickfont=dict(size=16),
727+
showgrid=False,
728+
showline=True,
729+
zeroline=True,
730+
showticklabels=False,
731+
mirror="ticks",
732+
ticks="inside",
733+
linewidth=2,
734+
tickwidth=2,
735+
zerolinewidth=2,
736+
range=[-5, 9],
737+
linecolor="rgb(71,71,71)",
738+
gridcolor="white",
739+
zerolinecolor="white",
740+
matches="y",
741+
anchor="x2",
742+
)
734743

735744
layout = dict(
736745
title="",
@@ -748,7 +757,7 @@ def get_figure(
748757
# clickmode="event+select"
749758
)
750759

751-
figure = {"data": bstraces + dostraces, "layout": layout}
760+
figure = {"data": traces, "layout": layout}
752761

753762
legend = dict(
754763
x=1.02,

0 commit comments

Comments
 (0)