@@ -138,24 +138,25 @@ def _sub_layouts(self):
138
138
}
139
139
140
140
def layout (self ):
141
+ sub_layouts = self ._sub_layouts
141
142
return html .Div (
142
143
[
143
- Columns ([Column ([self . _sub_layouts ["graph" ]])]),
144
+ Columns ([Column ([sub_layouts ["graph" ]])]),
144
145
Columns (
145
146
[
146
147
Column (
147
148
[
148
- self . _sub_layouts ["convention" ],
149
- self . _sub_layouts ["label-select" ],
150
- self . _sub_layouts ["dos-select" ],
149
+ sub_layouts ["convention" ],
150
+ sub_layouts ["label-select" ],
151
+ sub_layouts ["dos-select" ],
151
152
]
152
153
)
153
154
]
154
155
),
155
156
Columns (
156
157
[
157
- Column ([Label ("Summary" ), self . _sub_layouts ["table" ]]),
158
- Column ([Label ("Brillouin Zone" ), self . _sub_layouts ["zone" ]]),
158
+ Column ([Label ("Summary" ), sub_layouts ["table" ]]),
159
+ Column ([Label ("Brillouin Zone" ), sub_layouts ["zone" ]]),
159
160
]
160
161
),
161
162
]
@@ -174,7 +175,7 @@ def _get_bs_dos(data):
174
175
bandstructure_symm_line = data .get ("bandstructure_symm_line" )
175
176
density_of_states = data .get ("density_of_states" )
176
177
177
- if not mpid and ( bandstructure_symm_line is None or density_of_states is None ) :
178
+ if not mpid and bandstructure_symm_line is None and density_of_states is None :
178
179
return None , None
179
180
180
181
if mpid :
@@ -615,99 +616,104 @@ def get_figure(
615
616
616
617
return go .Figure (layout = empty_plot_style )
617
618
619
+ # -- Add trace data to plots
620
+
621
+ traces = []
622
+ xaxis_style = {}
623
+ yaxis_style = {}
624
+ xaxis_style_dos = {}
625
+ yaxis_style_dos = {}
618
626
if bs :
619
627
bstraces , bs_data = BandstructureAndDosComponent .get_bandstructure_traces (
620
628
bs , path_convention = path_convention , energy_window = energy_window
621
629
)
630
+ traces += bstraces
631
+
632
+ xaxis_style = dict (
633
+ title = dict (text = "Wave Vector" , font = dict (size = 16 )),
634
+ tickmode = "array" ,
635
+ tickvals = bs_data ["ticks" ]["distance" ],
636
+ ticktext = bs_data ["ticks" ]["label" ],
637
+ tickfont = dict (size = 16 ),
638
+ ticks = "inside" ,
639
+ tickwidth = 2 ,
640
+ showgrid = False ,
641
+ showline = True ,
642
+ zeroline = False ,
643
+ linewidth = 2 ,
644
+ mirror = True ,
645
+ range = [0 , bs_data ["ticks" ]["distance" ][- 1 ]],
646
+ linecolor = "rgb(71,71,71)" ,
647
+ gridcolor = "white" ,
648
+ )
649
+
650
+ yaxis_style = dict (
651
+ title = dict (text = "E−E<sub>fermi</sub> (eV)" , font = dict (size = 16 )),
652
+ tickfont = dict (size = 16 ),
653
+ showgrid = False ,
654
+ showline = True ,
655
+ zeroline = True ,
656
+ mirror = "ticks" ,
657
+ ticks = "inside" ,
658
+ linewidth = 2 ,
659
+ tickwidth = 2 ,
660
+ zerolinewidth = 2 ,
661
+ range = [- 5 , 9 ],
662
+ linecolor = "rgb(71,71,71)" ,
663
+ gridcolor = "white" ,
664
+ zerolinecolor = "white" ,
665
+ )
622
666
623
667
if dos :
624
668
dostraces = BandstructureAndDosComponent .get_dos_traces (
625
669
dos , dos_select = dos_select , energy_window = energy_window
626
670
)
671
+ traces += dostraces
672
+
673
+ rmax = max (
674
+ [
675
+ max (dostraces [0 ]["x" ]),
676
+ abs (min (dostraces [0 ]["x" ])),
677
+ max (dostraces [1 ]["x" ]),
678
+ abs (min (dostraces [1 ]["x" ])),
679
+ ]
680
+ )
627
681
628
- # TODO: add logic to handle if bstraces and/or dostraces not present
629
-
630
- rmax = max (
631
- [
632
- max (dostraces [0 ]["x" ]),
633
- abs (min (dostraces [0 ]["x" ])),
634
- max (dostraces [1 ]["x" ]),
635
- abs (min (dostraces [1 ]["x" ])),
636
- ]
637
- )
638
-
639
- # -- Add trace data to plots
640
-
641
- xaxis_style = dict (
642
- title = dict (text = "Wave Vector" , font = dict (size = 16 )),
643
- tickmode = "array" ,
644
- tickvals = bs_data ["ticks" ]["distance" ],
645
- ticktext = bs_data ["ticks" ]["label" ],
646
- tickfont = dict (size = 16 ),
647
- ticks = "inside" ,
648
- tickwidth = 2 ,
649
- showgrid = False ,
650
- showline = True ,
651
- zeroline = False ,
652
- linewidth = 2 ,
653
- mirror = True ,
654
- range = [0 , bs_data ["ticks" ]["distance" ][- 1 ]],
655
- linecolor = "rgb(71,71,71)" ,
656
- gridcolor = "white" ,
657
- )
658
-
659
- yaxis_style = dict (
660
- title = dict (text = "E−E<sub>fermi</sub> (eV)" , font = dict (size = 16 )),
661
- tickfont = dict (size = 16 ),
662
- showgrid = False ,
663
- showline = True ,
664
- zeroline = True ,
665
- mirror = "ticks" ,
666
- ticks = "inside" ,
667
- linewidth = 2 ,
668
- tickwidth = 2 ,
669
- zerolinewidth = 2 ,
670
- range = [- 5 , 9 ],
671
- linecolor = "rgb(71,71,71)" ,
672
- gridcolor = "white" ,
673
- zerolinecolor = "white" ,
674
- )
675
-
676
- xaxis_style_dos = dict (
677
- title = dict (text = "Density of States" , font = dict (size = 16 )),
678
- tickfont = dict (size = 16 ),
679
- showgrid = False ,
680
- showline = True ,
681
- zeroline = False ,
682
- mirror = True ,
683
- ticks = "inside" ,
684
- linewidth = 2 ,
685
- tickwidth = 2 ,
686
- range = [- rmax * 1.1 * int (len (bs_data ["energy" ].keys ()) == 2 ), rmax * 1.1 ,],
687
- linecolor = "rgb(71,71,71)" ,
688
- gridcolor = "white" ,
689
- zerolinecolor = "white" ,
690
- zerolinewidth = 2 ,
691
- )
682
+ xaxis_style_dos = dict (
683
+ title = dict (text = "Density of States" , font = dict (size = 16 )),
684
+ tickfont = dict (size = 16 ),
685
+ showgrid = False ,
686
+ showline = True ,
687
+ zeroline = False ,
688
+ mirror = True ,
689
+ ticks = "inside" ,
690
+ linewidth = 2 ,
691
+ tickwidth = 2 ,
692
+ range = [- rmax * 1.1 * int (len (dos .densities ) == 2 ), rmax * 1.1 ,],
693
+ linecolor = "rgb(71,71,71)" ,
694
+ gridcolor = "white" ,
695
+ zerolinecolor = "white" ,
696
+ zerolinewidth = 2 ,
697
+ )
692
698
693
- yaxis_style_dos = dict (
694
- tickfont = dict (size = 16 ),
695
- showgrid = False ,
696
- showline = True ,
697
- zeroline = True ,
698
- showticklabels = False ,
699
- mirror = "ticks" ,
700
- ticks = "inside" ,
701
- linewidth = 2 ,
702
- tickwidth = 2 ,
703
- zerolinewidth = 2 ,
704
- range = [- 5 , 9 ],
705
- linecolor = "rgb(71,71,71)" ,
706
- gridcolor = "white" ,
707
- zerolinecolor = "white" ,
708
- matches = "y" ,
709
- anchor = "x2" ,
710
- )
699
+ yaxis_style_dos = dict (
700
+ tickfont = dict (size = 16 ),
701
+ showgrid = False ,
702
+ showline = True ,
703
+ zeroline = True ,
704
+ showticklabels = False ,
705
+ mirror = "ticks" ,
706
+ ticks = "inside" ,
707
+ linewidth = 2 ,
708
+ tickwidth = 2 ,
709
+ zerolinewidth = 2 ,
710
+ range = [- 5 , 9 ],
711
+ linecolor = "rgb(71,71,71)" ,
712
+ gridcolor = "white" ,
713
+ zerolinecolor = "white" ,
714
+ matches = "y" ,
715
+ anchor = "x2" ,
716
+ )
711
717
712
718
layout = dict (
713
719
title = "" ,
@@ -725,7 +731,7 @@ def get_figure(
725
731
# clickmode="event+select"
726
732
)
727
733
728
- figure = {"data" : bstraces + dostraces , "layout" : layout }
734
+ figure = {"data" : traces , "layout" : layout }
729
735
730
736
legend = dict (
731
737
x = 1.02 ,
0 commit comments