@@ -158,24 +158,25 @@ def _sub_layouts(self):
158
158
}
159
159
160
160
def layout (self ):
161
+ sub_layouts = self ._sub_layouts
161
162
return html .Div (
162
163
[
163
- Columns ([Column ([self . _sub_layouts ["graph" ]])]),
164
+ Columns ([Column ([sub_layouts ["graph" ]])]),
164
165
Columns (
165
166
[
166
167
Column (
167
168
[
168
- self . _sub_layouts ["convention" ],
169
- self . _sub_layouts ["label-select" ],
170
- self . _sub_layouts ["dos-select" ],
169
+ sub_layouts ["convention" ],
170
+ sub_layouts ["label-select" ],
171
+ sub_layouts ["dos-select" ],
171
172
]
172
173
)
173
174
]
174
175
),
175
176
Columns (
176
177
[
177
- Column ([Label ("Summary" ), self . _sub_layouts ["table" ]]),
178
- Column ([Label ("Brillouin Zone" ), self . _sub_layouts ["zone" ]]),
178
+ Column ([Label ("Summary" ), sub_layouts ["table" ]]),
179
+ Column ([Label ("Brillouin Zone" ), sub_layouts ["zone" ]]),
179
180
]
180
181
),
181
182
]
@@ -194,7 +195,7 @@ def _get_bs_dos(data):
194
195
bandstructure_symm_line = data .get ("bandstructure_symm_line" )
195
196
density_of_states = data .get ("density_of_states" )
196
197
197
- if not mpid and ( bandstructure_symm_line is None or density_of_states is None ) :
198
+ if not mpid and bandstructure_symm_line is None and density_of_states is None :
198
199
return None , None
199
200
200
201
if mpid :
@@ -560,11 +561,10 @@ def get_dos_traces(dos, dos_select, energy_window=(-6.0, 10.0)):
560
561
561
562
dostraces .append (trace_tdos )
562
563
563
- ele_dos = dos .get_element_dos ()
564
- [str (entry ) for entry in ele_dos .keys ()]
565
-
566
- if dos_select == "ap" :
567
- proj_data = ele_dos
564
+ if dos_select == "tot" :
565
+ proj_data = {}
566
+ elif dos_select == "ap" :
567
+ proj_data = dos .get_element_dos ()
568
568
elif dos_select == "op" :
569
569
proj_data = dos .get_spd_dos ()
570
570
elif "orb" in dos_select :
@@ -635,102 +635,111 @@ def get_figure(
635
635
636
636
return go .Figure (layout = empty_plot_style )
637
637
638
+ # -- Add trace data to plots
639
+
640
+ traces = []
641
+ xaxis_style = {}
642
+ yaxis_style = {}
643
+ xaxis_style_dos = {}
644
+ yaxis_style_dos = {}
638
645
if bs :
639
646
bstraces , bs_data = BandstructureAndDosComponent .get_bandstructure_traces (
640
647
bs , path_convention = path_convention , energy_window = energy_window
641
648
)
649
+ traces += bstraces
650
+
651
+ xaxis_style = dict (
652
+ title = dict (text = "Wave Vector" , font = dict (size = 16 )),
653
+ tickmode = "array" ,
654
+ tickvals = bs_data ["ticks" ]["distance" ],
655
+ ticktext = bs_data ["ticks" ]["label" ],
656
+ tickfont = dict (size = 16 ),
657
+ ticks = "inside" ,
658
+ tickwidth = 2 ,
659
+ showgrid = False ,
660
+ showline = True ,
661
+ zeroline = False ,
662
+ linewidth = 2 ,
663
+ mirror = True ,
664
+ range = [0 , bs_data ["ticks" ]["distance" ][- 1 ]],
665
+ linecolor = "rgb(71,71,71)" ,
666
+ gridcolor = "white" ,
667
+ )
668
+
669
+ yaxis_style = dict (
670
+ title = dict (text = "E−E<sub>fermi</sub> (eV)" , font = dict (size = 16 )),
671
+ tickfont = dict (size = 16 ),
672
+ showgrid = False ,
673
+ showline = True ,
674
+ zeroline = True ,
675
+ mirror = "ticks" ,
676
+ ticks = "inside" ,
677
+ linewidth = 2 ,
678
+ tickwidth = 2 ,
679
+ zerolinewidth = 2 ,
680
+ range = [- 5 , 9 ],
681
+ linecolor = "rgb(71,71,71)" ,
682
+ gridcolor = "white" ,
683
+ zerolinecolor = "white" ,
684
+ )
642
685
643
686
if dos :
644
687
dostraces = BandstructureAndDosComponent .get_dos_traces (
645
688
dos , dos_select = dos_select , energy_window = energy_window
646
689
)
690
+ traces += dostraces
647
691
648
- # TODO: add logic to handle if bstraces and/or dostraces not present
649
-
650
- rmax = max (
651
- [
692
+ list_max = [
652
693
max (dostraces [0 ]["x" ]),
653
694
abs (min (dostraces [0 ]["x" ])),
654
- max (dostraces [1 ]["x" ]),
655
- abs (min (dostraces [1 ]["x" ])),
656
695
]
657
- )
658
-
659
- # -- Add trace data to plots
660
-
661
- xaxis_style = dict (
662
- title = dict (text = "Wave Vector" , font = dict (size = 16 )),
663
- tickmode = "array" ,
664
- tickvals = bs_data ["ticks" ]["distance" ],
665
- ticktext = bs_data ["ticks" ]["label" ],
666
- tickfont = dict (size = 16 ),
667
- ticks = "inside" ,
668
- tickwidth = 2 ,
669
- showgrid = False ,
670
- showline = True ,
671
- zeroline = False ,
672
- linewidth = 2 ,
673
- mirror = True ,
674
- range = [0 , bs_data ["ticks" ]["distance" ][- 1 ]],
675
- linecolor = "rgb(71,71,71)" ,
676
- gridcolor = "white" ,
677
- )
678
-
679
- yaxis_style = dict (
680
- title = dict (text = "E−E<sub>fermi</sub> (eV)" , font = dict (size = 16 )),
681
- tickfont = dict (size = 16 ),
682
- showgrid = False ,
683
- showline = True ,
684
- zeroline = True ,
685
- mirror = "ticks" ,
686
- ticks = "inside" ,
687
- linewidth = 2 ,
688
- tickwidth = 2 ,
689
- zerolinewidth = 2 ,
690
- range = [- 5 , 9 ],
691
- linecolor = "rgb(71,71,71)" ,
692
- gridcolor = "white" ,
693
- zerolinecolor = "white" ,
694
- )
695
696
696
- xaxis_style_dos = dict (
697
- title = dict (text = "Density of States" , font = dict (size = 16 )),
698
- tickfont = dict (size = 16 ),
699
- showgrid = False ,
700
- showline = True ,
701
- zeroline = False ,
702
- mirror = True ,
703
- ticks = "inside" ,
704
- linewidth = 2 ,
705
- tickwidth = 2 ,
706
- range = [
707
- - rmax * 1.1 * int (len (bs_data ["energy" ].keys ()) == 2 ),
708
- rmax * 1.1 ,
709
- ],
710
- linecolor = "rgb(71,71,71)" ,
711
- gridcolor = "white" ,
712
- zerolinecolor = "white" ,
713
- zerolinewidth = 2 ,
714
- )
697
+ # check the max of the second dos trace only if spin polarized
698
+ spin_polarized = len (dos .densities .keys ()) == 2
699
+ if spin_polarized :
700
+ list_max .extend (
701
+ [
702
+ max (dostraces [1 ]["x" ]),
703
+ abs (min (dostraces [1 ]["x" ])),
704
+ ]
705
+ )
706
+ rmax = max (list_max )
707
+
708
+ xaxis_style_dos = dict (
709
+ title = dict (text = "Density of States" , font = dict (size = 16 )),
710
+ tickfont = dict (size = 16 ),
711
+ showgrid = False ,
712
+ showline = True ,
713
+ zeroline = False ,
714
+ mirror = True ,
715
+ ticks = "inside" ,
716
+ linewidth = 2 ,
717
+ tickwidth = 2 ,
718
+ range = [- rmax * 1.1 * int (len (dos .densities ) == 2 ), rmax * 1.1 ,],
719
+ linecolor = "rgb(71,71,71)" ,
720
+ gridcolor = "white" ,
721
+ zerolinecolor = "white" ,
722
+ zerolinewidth = 2 ,
723
+ )
715
724
716
- yaxis_style_dos = dict (
717
- tickfont = dict (size = 16 ),
718
- showgrid = False ,
719
- showline = True ,
720
- zeroline = True ,
721
- showticklabels = False ,
722
- mirror = "ticks" ,
723
- ticks = "inside" ,
724
- linewidth = 2 ,
725
- tickwidth = 2 ,
726
- zerolinewidth = 2 ,
727
- range = [- 5 , 9 ],
728
- linecolor = "rgb(71,71,71)" ,
729
- gridcolor = "white" ,
730
- zerolinecolor = "white" ,
731
- matches = "y" ,
732
- anchor = "x2" ,
733
- )
725
+ yaxis_style_dos = dict (
726
+ tickfont = dict (size = 16 ),
727
+ showgrid = False ,
728
+ showline = True ,
729
+ zeroline = True ,
730
+ showticklabels = False ,
731
+ mirror = "ticks" ,
732
+ ticks = "inside" ,
733
+ linewidth = 2 ,
734
+ tickwidth = 2 ,
735
+ zerolinewidth = 2 ,
736
+ range = [- 5 , 9 ],
737
+ linecolor = "rgb(71,71,71)" ,
738
+ gridcolor = "white" ,
739
+ zerolinecolor = "white" ,
740
+ matches = "y" ,
741
+ anchor = "x2" ,
742
+ )
734
743
735
744
layout = dict (
736
745
title = "" ,
@@ -748,7 +757,7 @@ def get_figure(
748
757
# clickmode="event+select"
749
758
)
750
759
751
- figure = {"data" : bstraces + dostraces , "layout" : layout }
760
+ figure = {"data" : traces , "layout" : layout }
752
761
753
762
legend = dict (
754
763
x = 1.02 ,
0 commit comments