@@ -541,11 +541,10 @@ def get_dos_traces(dos, dos_select, energy_window=(-6.0, 10.0)):
541
541
542
542
dostraces .append (trace_tdos )
543
543
544
- ele_dos = dos .get_element_dos ()
545
- elements = [str (entry ) for entry in ele_dos .keys ()]
546
-
547
- if dos_select == "ap" :
548
- proj_data = ele_dos
544
+ if dos_select == "tot" :
545
+ proj_data = {}
546
+ elif dos_select == "ap" :
547
+ proj_data = dos .get_element_dos ()
549
548
elif dos_select == "op" :
550
549
proj_data = dos .get_spd_dos ()
551
550
elif "orb" in dos_select :
@@ -670,14 +669,21 @@ def get_figure(
670
669
)
671
670
traces += dostraces
672
671
673
- rmax = max (
674
- [
675
- max (dostraces [0 ]["x" ]),
676
- abs (min (dostraces [0 ]["x" ])),
677
- max (dostraces [1 ]["x" ]),
678
- abs (min (dostraces [1 ]["x" ])),
679
- ]
680
- )
672
+ list_max = [
673
+ max (dostraces [0 ]["x" ]),
674
+ abs (min (dostraces [0 ]["x" ])),
675
+ ]
676
+
677
+ # check the max of the second dos trace only if spin polarized
678
+ spin_polarized = len (dos .densities .keys ()) == 2
679
+ if spin_polarized :
680
+ list_max .extend (
681
+ [
682
+ max (dostraces [1 ]["x" ]),
683
+ abs (min (dostraces [1 ]["x" ])),
684
+ ]
685
+ )
686
+ rmax = max (list_max )
681
687
682
688
xaxis_style_dos = dict (
683
689
title = dict (text = "Density of States" , font = dict (size = 16 )),
0 commit comments