Skip to content

Commit cb45297

Browse files
author
jiyuang
committed
modify plot-tools
1 parent 9c6c510 commit cb45297

File tree

7 files changed

+636
-403
lines changed

7 files changed

+636
-403
lines changed

tools/plot-tools/README.md

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<!--
22
* @Date: 2021-08-21 21:58:06
33
* @LastEditors: jiyuyang
4-
* @LastEditTime: 2021-08-26 14:51:04
4+
* @LastEditTime: 2022-01-03 17:21:08
55
66
-->
77

@@ -31,7 +31,7 @@ First, prepare a json file e.g. band-input.json:
3131
```
3232
| Property | Type | Note |
3333
| :------------: | :----------------------: | :------------------------------------------------------------: |
34-
| *filename* | `str` or `List[str]` | Bands data file output from ABACUS |
34+
| *bandfile* | `str` or `List[str]` | Bands data file output from ABACUS |
3535
| *efermi* | `float` or `List[float]` | Fermi level in eV |
3636
| *energy_range* | `list` | Range of energy in eV |
3737
| *shift* | `bool` | If set `'true'`, it will evaluate band gap. Default: `'false'` |
@@ -89,7 +89,6 @@ First, prepare a json file e.g. dos-input.json:
8989
]
9090
},
9191
"pdosfig": "pdos.png",
92-
"tdosfig": "tdos.png"
9392
}
9493
```
9594
If you only want to plot total DOS, you can modify `pdosfile` to `tdosfile` and do not set `species` and `pdosfig`.
@@ -104,7 +103,17 @@ If you only want to plot total DOS, you can modify `pdosfile` to `tdosfile` and
104103
| *tdosfig* | `str` | Output picture of total DOS |
105104
| *pdosfig* | `str` | Output picture of partial DOS |
106105

107-
Then, the following command will plot both total DOS and partial DOS:
106+
Then, the following command will plot total DOS:
108107
```shell
109-
abacus-plot -d dos-input.json
108+
abacus-plot -t tdos-input.json
110109
```
110+
111+
Then, the following command will plot partial DOS:
112+
```shell
113+
abacus-plot -p pdos-input.json
114+
```
115+
116+
Then, the following command will output parsed partial DOS to directory `PDOS_FILE`:
117+
```shell
118+
abacus-plot -o pdos-input.json
119+
```

tools/plot-tools/abacus_plot/band.py

Lines changed: 127 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
'''
2-
Date: 2021-05-08 11:47:09
2+
Date: 2021-12-29 10:27:01
33
LastEditors: jiyuyang
4-
LastEditTime: 2021-08-26 12:07:22
4+
LastEditTime: 2022-01-03 17:06:14
55
66
'''
77

88
from collections import OrderedDict, namedtuple
9+
import numpy as np
910
from os import PathLike
1011
from typing import Sequence, Tuple
11-
12-
import matplotlib.pyplot as plt
13-
import numpy as np
12+
from matplotlib.figure import Figure
1413
from matplotlib import axes
1514

1615
from abacus_plot.utils import energy_minus_efermi, list_elem2str, read_kpt
@@ -19,6 +18,16 @@
1918
class BandPlot:
2019
"""Plot band structure"""
2120

21+
def __init__(self, fig: Figure, ax: axes.Axes, **kwargs) -> None:
22+
self.fig = fig
23+
self.ax = ax
24+
self._lw = kwargs.pop('lw', 2)
25+
self._bwidth = kwargs.pop('bwdith', 3)
26+
self._label = kwargs.pop('label', None)
27+
self._color = kwargs.pop('color', None)
28+
self._linestyle = kwargs.pop('linestyle', 'solid')
29+
self.plot_params = kwargs
30+
2231
@classmethod
2332
def set_vcband(cls, energy: Sequence) -> Tuple[namedtuple, namedtuple]:
2433
"""Separate valence and conduct band
@@ -57,13 +66,12 @@ def read(cls, filename: PathLike) -> Tuple[np.ndarray, np.ndarray]:
5766
data = np.loadtxt(filename)
5867
X, y = np.split(data, (1, ), axis=1)
5968
x = X.flatten()
69+
6070
return x, y
6171

62-
@classmethod
63-
def _set_figure(cls, ax: axes.Axes, index: dict, range: Sequence):
72+
def _set_figure(self, index: dict, range: Sequence):
6473
"""set figure and axes for plotting
6574
66-
:params ax: matplotlib.axes.Axes object
6775
:params index: dict of label of points of x-axis and its index in data file. Range of x-axis based on index.value()
6876
:params range: range of y-axis
6977
"""
@@ -79,131 +87,157 @@ def _set_figure(cls, ax: axes.Axes, index: dict, range: Sequence):
7987
values.append(t)
8088

8189
# x-axis
82-
ax.set_xticks(values)
83-
ax.set_xticklabels(keys)
84-
ax.set_xlim(values[0], values[-1])
85-
ax.set_xlabel("Wave Vector")
90+
self.ax.set_xticks(values)
91+
self.ax.set_xticklabels(keys)
92+
self.ax.set_xlim(values[0], values[-1])
93+
if "xlabel_params" in self.plot_params.keys():
94+
self.ax.set_xlabel(
95+
"Wave Vector", **self.plot_params["xlabel_params"])
96+
else:
97+
self.ax.set_xlabel("Wave Vector", size=25)
8698

8799
# y-axis
88100
if range:
89-
ax.set_ylim(range[0], range[1])
90-
ax.set_ylabel(r"$E-E_{fermi}(eV)$")
101+
self.ax.set_ylim(range[0], range[1])
102+
if "ylabel_params" in self.plot_params.keys():
103+
self.ax.set_ylabel(
104+
"Energy(eV)", **self.plot_params["ylabel_params"])
105+
else:
106+
self.ax.set_ylabel("Energy(eV)", size=25)
107+
108+
# notes
109+
if "notes" in self.plot_params.keys():
110+
from matplotlib.offsetbox import AnchoredText
111+
if "s" in self.plot_params["notes"].keys() and len(self.plot_params["notes"].keys()) == 1:
112+
self.ax.add_artist(AnchoredText(self.plot_params["notes"]["s"], loc='upper left', prop=dict(size=25),
113+
borderpad=0.2, frameon=False))
114+
else:
115+
self.ax.add_artist(AnchoredText(**self.plot_params["notes"]))
91116

92-
# others
93-
ax.grid(axis='x', lw=1.2)
94-
ax.axhline(0, linestyle="--", c='b', lw=1.0)
95-
handles, labels = ax.get_legend_handles_labels()
96-
by_label = OrderedDict(zip(labels, handles))
97-
ax.legend(by_label.values(), by_label.keys())
117+
# ticks
118+
if "tick_params" in self.plot_params.keys():
119+
self.ax.tick_params(**self.plot_params["tick_params"])
120+
else:
121+
self.ax.tick_params(labelsize=25)
122+
123+
# frame
124+
bwidth = self._bwidth
125+
self.ax.spines['top'].set_linewidth(bwidth)
126+
self.ax.spines['right'].set_linewidth(bwidth)
127+
self.ax.spines['left'].set_linewidth(bwidth)
128+
self.ax.spines['bottom'].set_linewidth(bwidth)
129+
130+
# guides
131+
if "grid_params" in self.plot_params.keys():
132+
self.ax.grid(axis='x', **self.plot_params["grid_params"])
133+
else:
134+
self.ax.grid(axis='x', lw=1.2)
135+
if "hline_params" in self.plot_params.keys():
136+
self.ax.axhline(0, **self.plot_params["hline_params"])
137+
else:
138+
self.ax.axhline(0, linestyle="--", c='b', lw=1.0)
139+
140+
if self._label:
141+
handles, labels = self.ax.get_legend_handles_labels()
142+
by_label = OrderedDict(zip(labels, handles))
143+
if "legend_prop" in self.plot_params.keys():
144+
self.ax.legend(by_label.values(), by_label.keys(),
145+
prop=self.plot_params["legend_prop"])
146+
else:
147+
self.ax.legend(by_label.values(),
148+
by_label.keys(), prop={'size': 15})
98149

99-
@classmethod
100-
def plot(cls, x: Sequence, y: Sequence, index: Sequence, efermi: float = 0, energy_range: Sequence[float] = [], label: str = None, color: str = None, outfile: PathLike = 'band.png'):
150+
def plot(self, x: Sequence, y: Sequence, index: Sequence, efermi: float = 0, energy_range: Sequence[float] = []):
101151
"""Plot band structure
102152
103153
:params x, y: x-axis and y-axis coordinates
104154
:params index: special k-points label and its index in data file
105155
:params efermi: Fermi level in unit eV
106156
:params energy_range: range of energy to plot, its length equals to two
107-
:params label: band label. Default: ''
108-
:params color: band color. Default: 'black'
109-
:params outfile: band picture file name. Default: 'band.png'
110157
"""
111158

112-
fig, ax = plt.subplots()
113-
114-
if not color:
115-
color = 'black'
159+
if not self._color:
160+
self._color = 'black'
116161

117162
kpoints, energy = x, y
118163
energy = energy_minus_efermi(energy, efermi)
119164

120-
ax.plot(kpoints, energy, lw=0.8, color=color, label=label)
121-
cls._set_figure(ax, index, energy_range)
165+
self.ax.plot(kpoints, energy, lw=self._lw, color=self._color,
166+
label=self._label, linestyle=self._linestyle)
167+
self._set_figure(index, energy_range)
122168

123-
plt.savefig(outfile)
124-
125-
@classmethod
126-
def singleplot(cls, datafile: PathLike, kptfile: str = [], efermi: float = 0, energy_range: Sequence[float] = [], shift: bool = False, label: str = None, color: str = None, outfile: PathLike = 'band.png'):
169+
def singleplot(self, datafile: PathLike, kptfile: str = '', efermi: float = 0, energy_range: Sequence[float] = [], shift: bool = False):
127170
"""Plot band structure using data file
128171
129172
:params datafile: string of band date file
130173
:params kptfile: k-point file
131174
:params efermi: Fermi level in unit eV
132175
:params energy_range: range of energy to plot, its length equals to two
133176
:params shift: if sets True, it will calculate band gap. This parameter usually is suitable for semiconductor and insulator. Default: False
134-
:params label: band label. Default: ''
135-
:params color: band color. Default: 'black'
136-
:params outfile: band picture file name. Default: 'band.png'
137177
"""
138178

139-
fig, ax = plt.subplots()
140179
kpt = read_kpt(kptfile)
141180

142-
if not color:
143-
color = 'black'
181+
if not self._color:
182+
self._color = 'black'
144183

145-
kpoints, energy = cls.read(datafile)
184+
kpoints, energy = self.read(datafile)
185+
energy = energy_minus_efermi(energy, efermi)
146186
if shift:
147-
vb, cb = cls.set_vcband(energy_minus_efermi(energy, efermi))
148-
ax.plot(kpoints, np.vstack((vb.band, cb.band)).T,
149-
lw=0.8, color=color, label=label)
150-
cls.info(kpt.full_kpath, vb, cb)
187+
vb, cb = self.set_vcband(energy)
188+
self.ax.plot(kpoints, np.vstack((vb.band, cb.band)).T,
189+
lw=self._lw, color=self._color, label=self._label, linestyle=self._linestyle)
190+
self.info(kpt.full_kpath, vb, cb)
151191
else:
152-
ax.plot(kpoints, energy_minus_efermi(energy, efermi),
153-
lw=0.8, color=color, label=label)
192+
self.ax.plot(kpoints, energy,
193+
lw=self._lw, color=self._color, label=self._label, linestyle=self._linestyle)
154194
index = kpt.label_special_k
155-
cls._set_figure(ax, index, energy_range)
195+
self._set_figure(index, energy_range)
156196

157-
plt.savefig(outfile)
158-
159-
@classmethod
160-
def multiplot(cls, datafile: Sequence[PathLike], kptfile: str = '', efermi: Sequence[float] = [], energy_range: Sequence[float] = [], shift: bool = True, label: Sequence[str] = None, color: Sequence[str] = None, outfile: PathLike = 'band.png'):
197+
def multiplot(self, datafile: Sequence[PathLike], kptfile: str = '', efermi: Sequence[float] = [], energy_range: Sequence[float] = [], shift: bool = True):
161198
"""Plot more than two band structures using data file
162199
163200
:params datafile: list of path of band date file
164201
:params kptfile: k-point file
165202
:params efermi: list of Fermi levels in unit eV, its length equals to `filename`
166203
:params energy_range: range of energy to plot, its length equals to two
167204
:params shift: if sets True, it will calculate band gap. This parameter usually is suitable for semiconductor and insulator. Default: False
168-
:params label: list of band labels, its length equals to `filename`
169-
:params color: list of band colors, its length equals to `filename`
170-
:params outfile: band picture file name. Default: 'band.png'
171205
"""
172206

173-
fig, ax = plt.subplots()
174207
kpt = read_kpt(kptfile)
175208

176209
if not efermi:
177210
efermi = [0.0 for i in range(len(datafile))]
178-
if not label:
179-
label = ['' for i in range(len(datafile))]
180-
if not color:
181-
color = ['black' for i in range(len(datafile))]
211+
if not self._label:
212+
self._label = ['' for i in range(len(datafile))]
213+
if not self._color:
214+
self._color = ['black' for i in range(len(datafile))]
215+
if not self._linestyle:
216+
self._linestyle = ['solid' for i in range(len(datafile))]
182217

183218
emin = -np.inf
184219
emax = np.inf
185220
for i, file in enumerate(datafile):
186-
kpoints, energy = cls.read(file)
221+
kpoints, energy = self.read(file)
187222
if shift:
188-
vb, cb = cls.set_vcband(energy_minus_efermi(energy, efermi[i]))
223+
vb, cb = self.set_vcband(
224+
energy_minus_efermi(energy, efermi[i]))
189225
energy_min = np.min(vb.band)
190226
energy_max = np.max(cb.band)
191227
if energy_min > emin:
192228
emin = energy_min
193229
if energy_max < emax:
194230
emax = energy_max
195231

196-
ax.plot(kpoints, np.vstack((vb.band, cb.band)).T,
197-
lw=0.8, color=color[i], label=label[i])
198-
cls.info(kpt.full_kpath, vb, cb)
232+
self.ax.plot(kpoints, np.vstack((vb.band, cb.band)).T,
233+
lw=self._lw, color=self._color[i], label=self._label[i], linestyle=self._linestyle[i])
234+
self.info(kpt.full_kpath, vb, cb)
199235
else:
200-
ax.plot(kpoints, energy_minus_efermi(energy, efermi[i]),
201-
lw=0.8, color=color[i], label=label[i])
236+
self.ax.plot(kpoints, energy_minus_efermi(energy, efermi[i]),
237+
lw=self._lw, color=self._color[i], label=self._label[i], linestyle=self._linestyle[i])
202238

203239
index = kpt.label_special_k
204-
cls._set_figure(ax, index, energy_range)
205-
206-
plt.savefig(outfile)
240+
self._set_figure(index, energy_range)
207241

208242
@classmethod
209243
def bandgap(cls, vb: namedtuple, cb: namedtuple):
@@ -254,3 +288,25 @@ def band_type(vbm_x, cbm_x):
254288
for i, j in enumerate(cbm_k):
255289
if i != 0:
256290
print(f"{''.ljust(30)}{' '.join(list_elem2str(j))}", flush=True)
291+
292+
293+
if __name__ == "__main__":
294+
import matplotlib.pyplot as plt
295+
from pathlib import Path
296+
parent = Path(r"D:\ustc\TEST\HOIP\double HOIP\result\bond")
297+
name = "CsAgBiBr"
298+
path = parent/name
299+
notes = {'s': '(b)'}
300+
datafile = [path/"soc.dat", path/"non-soc.dat"]
301+
kptfile = path/"KPT"
302+
fig, ax = plt.subplots(figsize=(12, 12))
303+
label = ["with SOC", "without SOC"]
304+
color = ["r", "g"]
305+
linestyle = ["solid", "dashed"]
306+
band = BandPlot(fig, ax, notes=notes, label=label,
307+
color=color, linestyle=linestyle)
308+
energy_range = [-5, 6]
309+
efermi = [4.417301755850272, 4.920435541999894]
310+
shift = True
311+
band.multiplot(datafile, kptfile, efermi, energy_range, shift)
312+
fig.savefig("band.png")

0 commit comments

Comments
 (0)