11'''
2- Date: 2021-05-08 11:47:09
2+ Date: 2021-12-29 10:27:01
33LastEditors: jiyuyang
4- LastEditTime: 2021-08-26 12:07:22
4+ LastEditTime: 2022-01-03 17:06:14
5566'''
77
88from collections import OrderedDict , namedtuple
9+ import numpy as np
910from os import PathLike
1011from typing import Sequence , Tuple
11-
12- import matplotlib .pyplot as plt
13- import numpy as np
12+ from matplotlib .figure import Figure
1413from matplotlib import axes
1514
1615from abacus_plot .utils import energy_minus_efermi , list_elem2str , read_kpt
1918class BandPlot :
2019 """Plot band structure"""
2120
21+ def __init__ (self , fig : Figure , ax : axes .Axes , ** kwargs ) -> None :
22+ self .fig = fig
23+ self .ax = ax
24+ self ._lw = kwargs .pop ('lw' , 2 )
25+ self ._bwidth = kwargs .pop ('bwdith' , 3 )
26+ self ._label = kwargs .pop ('label' , None )
27+ self ._color = kwargs .pop ('color' , None )
28+ self ._linestyle = kwargs .pop ('linestyle' , 'solid' )
29+ self .plot_params = kwargs
30+
2231 @classmethod
2332 def set_vcband (cls , energy : Sequence ) -> Tuple [namedtuple , namedtuple ]:
2433 """Separate valence and conduct band
@@ -57,13 +66,12 @@ def read(cls, filename: PathLike) -> Tuple[np.ndarray, np.ndarray]:
5766 data = np .loadtxt (filename )
5867 X , y = np .split (data , (1 , ), axis = 1 )
5968 x = X .flatten ()
69+
6070 return x , y
6171
62- @classmethod
63- def _set_figure (cls , ax : axes .Axes , index : dict , range : Sequence ):
72+ def _set_figure (self , index : dict , range : Sequence ):
6473 """set figure and axes for plotting
6574
66- :params ax: matplotlib.axes.Axes object
6775 :params index: dict of label of points of x-axis and its index in data file. Range of x-axis based on index.value()
6876 :params range: range of y-axis
6977 """
@@ -79,131 +87,157 @@ def _set_figure(cls, ax: axes.Axes, index: dict, range: Sequence):
7987 values .append (t )
8088
8189 # x-axis
82- ax .set_xticks (values )
83- ax .set_xticklabels (keys )
84- ax .set_xlim (values [0 ], values [- 1 ])
85- ax .set_xlabel ("Wave Vector" )
90+ self .ax .set_xticks (values )
91+ self .ax .set_xticklabels (keys )
92+ self .ax .set_xlim (values [0 ], values [- 1 ])
93+ if "xlabel_params" in self .plot_params .keys ():
94+ self .ax .set_xlabel (
95+ "Wave Vector" , ** self .plot_params ["xlabel_params" ])
96+ else :
97+ self .ax .set_xlabel ("Wave Vector" , size = 25 )
8698
8799 # y-axis
88100 if range :
89- ax .set_ylim (range [0 ], range [1 ])
90- ax .set_ylabel (r"$E-E_{fermi}(eV)$" )
101+ self .ax .set_ylim (range [0 ], range [1 ])
102+ if "ylabel_params" in self .plot_params .keys ():
103+ self .ax .set_ylabel (
104+ "Energy(eV)" , ** self .plot_params ["ylabel_params" ])
105+ else :
106+ self .ax .set_ylabel ("Energy(eV)" , size = 25 )
107+
108+ # notes
109+ if "notes" in self .plot_params .keys ():
110+ from matplotlib .offsetbox import AnchoredText
111+ if "s" in self .plot_params ["notes" ].keys () and len (self .plot_params ["notes" ].keys ()) == 1 :
112+ self .ax .add_artist (AnchoredText (self .plot_params ["notes" ]["s" ], loc = 'upper left' , prop = dict (size = 25 ),
113+ borderpad = 0.2 , frameon = False ))
114+ else :
115+ self .ax .add_artist (AnchoredText (** self .plot_params ["notes" ]))
91116
92- # others
93- ax .grid (axis = 'x' , lw = 1.2 )
94- ax .axhline (0 , linestyle = "--" , c = 'b' , lw = 1.0 )
95- handles , labels = ax .get_legend_handles_labels ()
96- by_label = OrderedDict (zip (labels , handles ))
97- ax .legend (by_label .values (), by_label .keys ())
117+ # ticks
118+ if "tick_params" in self .plot_params .keys ():
119+ self .ax .tick_params (** self .plot_params ["tick_params" ])
120+ else :
121+ self .ax .tick_params (labelsize = 25 )
122+
123+ # frame
124+ bwidth = self ._bwidth
125+ self .ax .spines ['top' ].set_linewidth (bwidth )
126+ self .ax .spines ['right' ].set_linewidth (bwidth )
127+ self .ax .spines ['left' ].set_linewidth (bwidth )
128+ self .ax .spines ['bottom' ].set_linewidth (bwidth )
129+
130+ # guides
131+ if "grid_params" in self .plot_params .keys ():
132+ self .ax .grid (axis = 'x' , ** self .plot_params ["grid_params" ])
133+ else :
134+ self .ax .grid (axis = 'x' , lw = 1.2 )
135+ if "hline_params" in self .plot_params .keys ():
136+ self .ax .axhline (0 , ** self .plot_params ["hline_params" ])
137+ else :
138+ self .ax .axhline (0 , linestyle = "--" , c = 'b' , lw = 1.0 )
139+
140+ if self ._label :
141+ handles , labels = self .ax .get_legend_handles_labels ()
142+ by_label = OrderedDict (zip (labels , handles ))
143+ if "legend_prop" in self .plot_params .keys ():
144+ self .ax .legend (by_label .values (), by_label .keys (),
145+ prop = self .plot_params ["legend_prop" ])
146+ else :
147+ self .ax .legend (by_label .values (),
148+ by_label .keys (), prop = {'size' : 15 })
98149
99- @classmethod
100- def plot (cls , x : Sequence , y : Sequence , index : Sequence , efermi : float = 0 , energy_range : Sequence [float ] = [], label : str = None , color : str = None , outfile : PathLike = 'band.png' ):
150+ def plot (self , x : Sequence , y : Sequence , index : Sequence , efermi : float = 0 , energy_range : Sequence [float ] = []):
101151 """Plot band structure
102152
103153 :params x, y: x-axis and y-axis coordinates
104154 :params index: special k-points label and its index in data file
105155 :params efermi: Fermi level in unit eV
106156 :params energy_range: range of energy to plot, its length equals to two
107- :params label: band label. Default: ''
108- :params color: band color. Default: 'black'
109- :params outfile: band picture file name. Default: 'band.png'
110157 """
111158
112- fig , ax = plt .subplots ()
113-
114- if not color :
115- color = 'black'
159+ if not self ._color :
160+ self ._color = 'black'
116161
117162 kpoints , energy = x , y
118163 energy = energy_minus_efermi (energy , efermi )
119164
120- ax .plot (kpoints , energy , lw = 0.8 , color = color , label = label )
121- cls ._set_figure (ax , index , energy_range )
165+ self .ax .plot (kpoints , energy , lw = self ._lw , color = self ._color ,
166+ label = self ._label , linestyle = self ._linestyle )
167+ self ._set_figure (index , energy_range )
122168
123- plt .savefig (outfile )
124-
125- @classmethod
126- def singleplot (cls , datafile : PathLike , kptfile : str = [], efermi : float = 0 , energy_range : Sequence [float ] = [], shift : bool = False , label : str = None , color : str = None , outfile : PathLike = 'band.png' ):
169+ def singleplot (self , datafile : PathLike , kptfile : str = '' , efermi : float = 0 , energy_range : Sequence [float ] = [], shift : bool = False ):
127170 """Plot band structure using data file
128171
129172 :params datafile: string of band date file
130173 :params kptfile: k-point file
131174 :params efermi: Fermi level in unit eV
132175 :params energy_range: range of energy to plot, its length equals to two
133176 :params shift: if sets True, it will calculate band gap. This parameter usually is suitable for semiconductor and insulator. Default: False
134- :params label: band label. Default: ''
135- :params color: band color. Default: 'black'
136- :params outfile: band picture file name. Default: 'band.png'
137177 """
138178
139- fig , ax = plt .subplots ()
140179 kpt = read_kpt (kptfile )
141180
142- if not color :
143- color = 'black'
181+ if not self . _color :
182+ self . _color = 'black'
144183
145- kpoints , energy = cls .read (datafile )
184+ kpoints , energy = self .read (datafile )
185+ energy = energy_minus_efermi (energy , efermi )
146186 if shift :
147- vb , cb = cls .set_vcband (energy_minus_efermi ( energy , efermi ) )
148- ax .plot (kpoints , np .vstack ((vb .band , cb .band )).T ,
149- lw = 0.8 , color = color , label = label )
150- cls .info (kpt .full_kpath , vb , cb )
187+ vb , cb = self .set_vcband (energy )
188+ self . ax .plot (kpoints , np .vstack ((vb .band , cb .band )).T ,
189+ lw = self . _lw , color = self . _color , label = self . _label , linestyle = self . _linestyle )
190+ self .info (kpt .full_kpath , vb , cb )
151191 else :
152- ax .plot (kpoints , energy_minus_efermi ( energy , efermi ) ,
153- lw = 0.8 , color = color , label = label )
192+ self . ax .plot (kpoints , energy ,
193+ lw = self . _lw , color = self . _color , label = self . _label , linestyle = self . _linestyle )
154194 index = kpt .label_special_k
155- cls ._set_figure (ax , index , energy_range )
195+ self ._set_figure (index , energy_range )
156196
157- plt .savefig (outfile )
158-
159- @classmethod
160- def multiplot (cls , datafile : Sequence [PathLike ], kptfile : str = '' , efermi : Sequence [float ] = [], energy_range : Sequence [float ] = [], shift : bool = True , label : Sequence [str ] = None , color : Sequence [str ] = None , outfile : PathLike = 'band.png' ):
197+ def multiplot (self , datafile : Sequence [PathLike ], kptfile : str = '' , efermi : Sequence [float ] = [], energy_range : Sequence [float ] = [], shift : bool = True ):
161198 """Plot more than two band structures using data file
162199
163200 :params datafile: list of path of band date file
164201 :params kptfile: k-point file
165202 :params efermi: list of Fermi levels in unit eV, its length equals to `filename`
166203 :params energy_range: range of energy to plot, its length equals to two
167204 :params shift: if sets True, it will calculate band gap. This parameter usually is suitable for semiconductor and insulator. Default: False
168- :params label: list of band labels, its length equals to `filename`
169- :params color: list of band colors, its length equals to `filename`
170- :params outfile: band picture file name. Default: 'band.png'
171205 """
172206
173- fig , ax = plt .subplots ()
174207 kpt = read_kpt (kptfile )
175208
176209 if not efermi :
177210 efermi = [0.0 for i in range (len (datafile ))]
178- if not label :
179- label = ['' for i in range (len (datafile ))]
180- if not color :
181- color = ['black' for i in range (len (datafile ))]
211+ if not self ._label :
212+ self ._label = ['' for i in range (len (datafile ))]
213+ if not self ._color :
214+ self ._color = ['black' for i in range (len (datafile ))]
215+ if not self ._linestyle :
216+ self ._linestyle = ['solid' for i in range (len (datafile ))]
182217
183218 emin = - np .inf
184219 emax = np .inf
185220 for i , file in enumerate (datafile ):
186- kpoints , energy = cls .read (file )
221+ kpoints , energy = self .read (file )
187222 if shift :
188- vb , cb = cls .set_vcband (energy_minus_efermi (energy , efermi [i ]))
223+ vb , cb = self .set_vcband (
224+ energy_minus_efermi (energy , efermi [i ]))
189225 energy_min = np .min (vb .band )
190226 energy_max = np .max (cb .band )
191227 if energy_min > emin :
192228 emin = energy_min
193229 if energy_max < emax :
194230 emax = energy_max
195231
196- ax .plot (kpoints , np .vstack ((vb .band , cb .band )).T ,
197- lw = 0.8 , color = color [i ], label = label [i ])
198- cls .info (kpt .full_kpath , vb , cb )
232+ self . ax .plot (kpoints , np .vstack ((vb .band , cb .band )).T ,
233+ lw = self . _lw , color = self . _color [i ], label = self . _label [ i ], linestyle = self . _linestyle [i ])
234+ self .info (kpt .full_kpath , vb , cb )
199235 else :
200- ax .plot (kpoints , energy_minus_efermi (energy , efermi [i ]),
201- lw = 0.8 , color = color [i ], label = label [i ])
236+ self . ax .plot (kpoints , energy_minus_efermi (energy , efermi [i ]),
237+ lw = self . _lw , color = self . _color [i ], label = self . _label [ i ], linestyle = self . _linestyle [i ])
202238
203239 index = kpt .label_special_k
204- cls ._set_figure (ax , index , energy_range )
205-
206- plt .savefig (outfile )
240+ self ._set_figure (index , energy_range )
207241
208242 @classmethod
209243 def bandgap (cls , vb : namedtuple , cb : namedtuple ):
@@ -254,3 +288,25 @@ def band_type(vbm_x, cbm_x):
254288 for i , j in enumerate (cbm_k ):
255289 if i != 0 :
256290 print (f"{ '' .ljust (30 )} { ' ' .join (list_elem2str (j ))} " , flush = True )
291+
292+
293+ if __name__ == "__main__" :
294+ import matplotlib .pyplot as plt
295+ from pathlib import Path
296+ parent = Path (r"D:\ustc\TEST\HOIP\double HOIP\result\bond" )
297+ name = "CsAgBiBr"
298+ path = parent / name
299+ notes = {'s' : '(b)' }
300+ datafile = [path / "soc.dat" , path / "non-soc.dat" ]
301+ kptfile = path / "KPT"
302+ fig , ax = plt .subplots (figsize = (12 , 12 ))
303+ label = ["with SOC" , "without SOC" ]
304+ color = ["r" , "g" ]
305+ linestyle = ["solid" , "dashed" ]
306+ band = BandPlot (fig , ax , notes = notes , label = label ,
307+ color = color , linestyle = linestyle )
308+ energy_range = [- 5 , 6 ]
309+ efermi = [4.417301755850272 , 4.920435541999894 ]
310+ shift = True
311+ band .multiplot (datafile , kptfile , efermi , energy_range , shift )
312+ fig .savefig ("band.png" )
0 commit comments